Skip to content

Commit 1f46c29

Browse files
[Spec Decode][Benchmark] Add Blazedit dataset (vllm-project#23605)
Signed-off-by: Ekagra Ranjan <[email protected]> Co-authored-by: Roger Wang <[email protected]>
1 parent b091376 commit 1f46c29

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

datasets.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
11011101
"from the ShareGPT dataset.",
11021102
)
11031103

1104+
blazedit_group = parser.add_argument_group("blazedit dataset options")
1105+
blazedit_group.add_argument(
1106+
"--blazedit-min-distance",
1107+
type=float,
1108+
default=0.0,
1109+
help=
1110+
"Minimum distance for blazedit dataset. Min: 0, Max: 1.0",
1111+
)
1112+
blazedit_group.add_argument(
1113+
"--blazedit-max-distance",
1114+
type=float,
1115+
default=1.0,
1116+
help=
1117+
"Maximum distance for blazedit dataset. Min: 0, Max: 1.0",
1118+
)
1119+
11041120
random_group = parser.add_argument_group("random dataset options")
11051121
random_group.add_argument(
11061122
"--random-input-len",
@@ -1333,6 +1349,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
13331349
elif args.dataset_name == "hf":
13341350
# all following datasets are implemented from the
13351351
# HuggingFaceDataset base class
1352+
hf_kwargs = {}
13361353
if (
13371354
args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS
13381355
or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
@@ -1376,6 +1393,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
13761393
):
13771394
dataset_class = ASRDataset
13781395
args.hf_split = "train"
1396+
elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS:
1397+
dataset_class = BlazeditDataset
1398+
args.hf_split = "train"
1399+
hf_kwargs = {
1400+
"min_distance": args.blazedit_min_distance,
1401+
"max_distance": args.blazedit_max_distance,
1402+
}
13791403
elif (
13801404
args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS
13811405
or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
@@ -1415,6 +1439,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
14151439
tokenizer=tokenizer,
14161440
output_len=args.hf_output_len,
14171441
request_id_prefix=args.request_id_prefix,
1442+
**hf_kwargs
14181443
)
14191444

14201445
else:
@@ -2090,6 +2115,94 @@ def sample(
20902115
return sampled_requests
20912116

20922117

2118+
# -----------------------------------------------------------------------------
2119+
# Blazedit Dataset Implementation
2120+
# -----------------------------------------------------------------------------
2121+
2122+
2123+
class BlazeditDataset(HuggingFaceDataset):
    """
    Blazedit Dataset.
    https://github.com/ise-uiuc/blazedit

    5k char version: vdaita/edit_5k_char
    10k char version: vdaita/edit_10k_char
    """  # noqa: E501

    # Generated files are roughly the size of the originals (~5k / ~10k
    # chars). At an assumed ~3 chars per token, 10k chars is ~3333 tokens;
    # 4000 leaves headroom so generations are not truncated.
    DEFAULT_OUTPUT_LEN = 4000
    SUPPORTED_DATASET_PATHS = {
        "vdaita/edit_5k_char",
        "vdaita/edit_10k_char",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        min_distance: float = 0.0,
        max_distance: float = 1.0,
        **kwargs,
    ) -> list:
        """Build chat-formatted code-edit prompts from the dataset.

        Rows whose normalized Levenshtein distance falls outside
        [min_distance, max_distance] are skipped; if fewer than
        num_requests rows survive, maybe_oversample_requests pads the
        result. Request ids use the row's enumeration index, so they may
        be non-contiguous when rows are filtered out.
        """
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN
        requests: list = []

        for idx, row in enumerate(self.data):
            if len(requests) >= num_requests:
                break

            # Keep only rows inside the requested band of Levenshtein
            # distance normalized by code length.
            if not (min_distance <= row["norm_distance"] <= max_distance):
                continue

            code = row["code"]
            change_request = row["change_request"]

            # template copied from
            # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
            instruction = f"""Given a code file, please apply the change requests and generate the new file.

Original file:
```python
{code}
```

Change request:
{change_request}

Please generate the new code file in the "New file" section below."""  # noqa: E501

            # Render through the model's chat template (text only; the
            # benchmark client tokenizes/serves the prompt itself).
            prompt = tokenizer.apply_chat_template(
                [{
                    "role": "user",
                    "content": instruction
                }],
                add_generation_prompt=True,
                tokenize=False,
            )

            requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=len(tokenizer(prompt).input_ids),
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(idx),
                ))
        self.maybe_oversample_requests(requests, num_requests,
                                       request_id_prefix)

        return requests
2205+
20932206
# -----------------------------------------------------------------------------
20942207
# AIMO Dataset Implementation
20952208
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)