@@ -1101,6 +1101,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
11011101 "from the ShareGPT dataset." ,
11021102 )
11031103
1104+ blazedit_group = parser .add_argument_group ("blazedit dataset options" )
1105+ blazedit_group .add_argument (
1106+ "--blazedit-min-distance" ,
1107+ type = float ,
1108+ default = 0.0 ,
1109+ help =
1110+ "Minimum distance for blazedit dataset. Min: 0, Max: 1.0" ,
1111+ )
1112+ blazedit_group .add_argument (
1113+ "--blazedit-max-distance" ,
1114+ type = float ,
1115+ default = 1.0 ,
1116+ help =
1117+ "Maximum distance for blazedit dataset. Min: 0, Max: 1.0" ,
1118+ )
1119+
11041120 random_group = parser .add_argument_group ("random dataset options" )
11051121 random_group .add_argument (
11061122 "--random-input-len" ,
@@ -1333,6 +1349,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
13331349 elif args .dataset_name == "hf" :
13341350 # all following datasets are implemented from the
13351351 # HuggingFaceDataset base class
1352+ hf_kwargs = {}
13361353 if (
13371354 args .dataset_path in VisionArenaDataset .SUPPORTED_DATASET_PATHS
13381355 or args .hf_name in VisionArenaDataset .SUPPORTED_DATASET_PATHS
@@ -1376,6 +1393,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
13761393 ):
13771394 dataset_class = ASRDataset
13781395 args .hf_split = "train"
1396+ elif args .dataset_path in BlazeditDataset .SUPPORTED_DATASET_PATHS :
1397+ dataset_class = BlazeditDataset
1398+ args .hf_split = "train"
1399+ hf_kwargs = {
1400+ "min_distance" : args .blazedit_min_distance ,
1401+ "max_distance" : args .blazedit_max_distance ,
1402+ }
13791403 elif (
13801404 args .dataset_path in MLPerfDataset .SUPPORTED_DATASET_PATHS
13811405 or args .hf_name in MLPerfDataset .SUPPORTED_DATASET_PATHS
@@ -1415,6 +1439,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
14151439 tokenizer = tokenizer ,
14161440 output_len = args .hf_output_len ,
14171441 request_id_prefix = args .request_id_prefix ,
1442+ ** hf_kwargs
14181443 )
14191444
14201445 else :
@@ -2090,6 +2115,94 @@ def sample(
20902115 return sampled_requests
20912116
20922117
# -----------------------------------------------------------------------------
# Blazedit Dataset Implementation
# -----------------------------------------------------------------------------


class BlazeditDataset(HuggingFaceDataset):
    """
    Blazedit Dataset.
    https://github.com/ise-uiuc/blazedit

    5k char version: vdaita/edit_5k_char
    10k char version: vdaita/edit_10k_char
    """  # noqa: E501

    # 5k char version will have output as ~5k chars
    # 10k char version will have output as ~10k chars
    # Assuming 3 chars per token, 10k chars is ~3333 tokens.
    # We set the default to 4000 to be safe.
    DEFAULT_OUTPUT_LEN = 4000
    SUPPORTED_DATASET_PATHS = {
        "vdaita/edit_5k_char",
        "vdaita/edit_10k_char",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        min_distance: float = 0.0,
        max_distance: float = 1.0,
        **kwargs,
    ) -> list:
        """Sample code-editing requests from the Blazedit dataset.

        Rows whose ``norm_distance`` (Levenshtein distance normalized by
        code length) falls outside ``[min_distance, max_distance]`` are
        skipped, so heavier or lighter edits can be benchmarked selectively.

        Args:
            tokenizer: Tokenizer used to apply the chat template and count
                prompt tokens.
            num_requests: Number of requests to collect.
            output_len: Expected output length per request; defaults to
                ``DEFAULT_OUTPUT_LEN`` when ``None``.
            request_id_prefix: Prefix prepended to each request id.
            min_distance: Lower bound (inclusive) on ``norm_distance``.
            max_distance: Upper bound (inclusive) on ``norm_distance``.
            **kwargs: Ignored; accepted for interface compatibility with
                other HuggingFaceDataset subclasses.

        Returns:
            A list of SampleRequest objects.
        """
        output_len = (output_len
                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
        sampled_requests = []

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            code = item["code"]
            change_request = item["change_request"]
            norm_distance = item["norm_distance"]

            # Filter by the Levenshtein distance normalized by code length.
            if norm_distance < min_distance or norm_distance > max_distance:
                continue

            # Template copied from
            # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
            instruction = f"""Given a code file, please apply the change requests and generate the new file.

Original file:
```python
{code}
```

Change request:
{change_request}

Please generate the new code file in the "New file" section below."""  # noqa: E501

            # Render the instruction through the model's chat template.
            prompt = tokenizer.apply_chat_template(
                [{
                    "role": "user",
                    "content": instruction
                }],
                add_generation_prompt=True,
                tokenize=False,
            )

            prompt_len = len(tokenizer(prompt).input_ids)

            # Number request ids by accepted-sample count (not the raw
            # dataset row index) so ids stay dense even when rows are
            # filtered out above; this keeps them consistent with the
            # numbering maybe_oversample_requests continues from and
            # avoids potential id collisions after oversampling.
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(len(sampled_requests)),
                ))
        self.maybe_oversample_requests(sampled_requests, num_requests,
                                       request_id_prefix)

        return sampled_requests

20932206# -----------------------------------------------------------------------------
20942207# AIMO Dataset Implementation
20952208# -----------------------------------------------------------------------------
0 commit comments