@@ -1227,6 +1227,16 @@ def normalize(d: dict) -> dict[tuple[int, int, int], float]:
         type=str,
         default=None,
         help="Split of the HF dataset.")
+    hf_group.add_argument(
+        "--hf-name",
+        type=str,
+        default=None,
+        help=(
+            "Name of the dataset on HuggingFace "
+            "(e.g., 'lmarena-ai/VisionArena-Chat'). "
+            "Specify this if your dataset-path is a local path."
+        ),
+    )
     hf_group.add_argument(
         "--hf-output-len",
         type=int,
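The new flag only identifies which HuggingFace dataset format the local files correspond to. Below is a minimal, self-contained sketch of the intended usage; the parser built here is a stand-alone illustration, not the benchmark script's real parser, and the flag values are assumptions.

    import argparse

    # Stand-alone sketch: a tiny parser with the same two flags, showing how
    # --hf-name composes with --dataset-path when the data lives on local disk.
    parser = argparse.ArgumentParser()
    hf_group = parser.add_argument_group("hf dataset options")
    hf_group.add_argument("--dataset-path", type=str, default=None)
    hf_group.add_argument(
        "--hf-name",
        type=str,
        default=None,
        help="Name of the dataset on HuggingFace "
        "(e.g., 'lmarena-ai/VisionArena-Chat'). "
        "Specify this if your dataset-path is a local path.",
    )

    # The dataset files are read from the local path, while --hf-name tells the
    # benchmark which supported dataset format/parser they correspond to.
    args = parser.parse_args(
        ["--dataset-path", "/data/vision-arena",
         "--hf-name", "lmarena-ai/VisionArena-Chat"]
    )
    print(args.dataset_path, args.hf_name)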
@@ -1307,28 +1317,53 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
     elif args.dataset_name == "hf":
         # all following datasets are implemented from the
         # HuggingFaceDataset base class
-        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
+        if (
+            args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
+        ):
             dataset_class = VisionArenaDataset
             args.hf_split = "train"
             args.hf_subset = None
-        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
+        elif (
+            args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in InstructCoderDataset.SUPPORTED_DATASET_PATHS
+        ):
             dataset_class = InstructCoderDataset
             args.hf_split = "train"
-        elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
+        elif (
+            args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in MTBenchDataset.SUPPORTED_DATASET_PATHS
+        ):
             dataset_class = MTBenchDataset
             args.hf_split = "train"
-        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
+        elif (
+            args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in ConversationDataset.SUPPORTED_DATASET_PATHS
+        ):
             dataset_class = ConversationDataset
-        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
+        elif (
+            args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in AIMODataset.SUPPORTED_DATASET_PATHS
+        ):
             dataset_class = AIMODataset
             args.hf_split = "train"
-        elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS:  # noqa: E501
+        elif (
+            args.dataset_path
+            in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS  # noqa: E501
+            or args.hf_name in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS
+        ):
             dataset_class = NextEditPredictionDataset
             args.hf_split = "train"
-        elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
+        elif (
+            args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in ASRDataset.SUPPORTED_DATASET_PATHS
+        ):
             dataset_class = ASRDataset
             args.hf_split = "train"
-        elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS:
+        elif (
+            args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
+        ):
             dataset_class = MLPerfDataset
             args.hf_split = "train"
         else:
@@ -1358,6 +1393,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             dataset_split=args.hf_split,
             random_seed=args.seed,
             no_stream=args.no_stream,
+            hf_name=args.hf_name,
         ).sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
@@ -1710,13 +1746,15 @@ def __init__(
         dataset_split: str,
         no_stream: bool = False,
         dataset_subset: Optional[str] = None,
+        hf_name: Optional[str] = None,
         **kwargs,
     ) -> None:
         super().__init__(dataset_path=dataset_path, **kwargs)
 
         self.dataset_split = dataset_split
         self.dataset_subset = dataset_subset
         self.load_stream = not no_stream
+        self.hf_name = hf_name or dataset_path
         self.load_data()
 
     def load_data(self) -> None:
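The `hf_name or dataset_path` fallback keeps the previous behaviour when the flag is omitted. A minimal illustration of that semantics; the helper below is hypothetical.

    def resolve_hf_name(dataset_path, hf_name=None):
        # Mirrors `self.hf_name = hf_name or dataset_path`: with no --hf-name,
        # the dataset path itself is still used as the HuggingFace identifier.
        return hf_name or dataset_path

    assert resolve_hf_name("lmarena-ai/VisionArena-Chat") == \
        "lmarena-ai/VisionArena-Chat"
    assert resolve_hf_name("/data/vision-arena",
                           "lmarena-ai/VisionArena-Chat") == \
        "lmarena-ai/VisionArena-Chat"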
@@ -1827,10 +1865,9 @@ def sample(
         for i, item in enumerate(self.data):
             if len(sampled_requests) >= num_requests:
                 break
-            parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
+            parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name)
             if parser_fn is None:
-                raise ValueError(
-                    f"Unsupported dataset path: {self.dataset_path}")
+                raise ValueError(f"Unsupported dataset path: {self.hf_name}")
             prompt = parser_fn(item)
             mm_content = process_image(item["images"][0])
             prompt_len = len(tokenizer(prompt).input_ids)
@@ -2099,10 +2136,9 @@ class NextEditPredictionDataset(HuggingFaceDataset):
     def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
                request_id_prefix: str = "",
                **kwargs):
-        formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
-            self.dataset_path)
+        formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.hf_name)
         if formatting_prompt_func is None:
-            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
+            raise ValueError(f"Unsupported dataset path: {self.hf_name}")
         samples = []
         for i, sample in enumerate(self.data):
             sample = formatting_prompt_func(sample)
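Both `sample()` methods now key their per-item parser lookup by `self.hf_name` rather than by the (possibly local) path. A hedged sketch of that lookup; the mapping contents and item structure are hypothetical.

    # Hypothetical mapping standing in for SUPPORTED_DATASET_PATHS /
    # MAPPING_PROMPT_FUNCS: HF name -> per-item prompt extraction function.
    PARSERS = {
        "lmarena-ai/VisionArena-Chat": lambda item: item["prompt"],
    }

    hf_name = "lmarena-ai/VisionArena-Chat"  # --hf-name, or dataset_path fallback
    parser_fn = PARSERS.get(hf_name)
    if parser_fn is None:
        raise ValueError(f"Unsupported dataset path: {hf_name}")
    print(parser_fn({"prompt": "Describe the image."}))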