From 3719691678fc04c8e58629842f44d82f69660465 Mon Sep 17 00:00:00 2001
From: yonigozlan
Date: Tue, 2 Dec 2025 22:15:08 +0000
Subject: [PATCH 1/3] change values back

---
 src/transformers/testing_utils.py             | 556 ++++++++++++++++++
 tests/models/clip/test_modeling_clip.py       |  19 +-
 .../kosmos2_5/test_modeling_kosmos2_5.py      | 118 ++--
 .../test_modeling_llava_next_video.py         |  43 +-
 4 files changed, 653 insertions(+), 83 deletions(-)

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 89bc2d750b28..1b106280a566 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -47,6 +47,7 @@ from unittest.mock import patch

 import httpx
+import libcst as cst
 import urllib3
 from huggingface_hub import create_repo, delete_repo
 from packaging import version
@@ -3754,6 +3755,561 @@ def patch_testing_methods_to_collect_info():
     _patch_with_call_info(unittest.case.TestCase, "assertGreaterEqual", _parse_call_info, target_args=("a", "b"))


+###############################################################################
+# Expectations Recording System
+###############################################################################
+
+
+def record_expectations(pairs=None):
+    """
+    Decorator that auto-updates hardcoded expectations in the test source file.
+
+    This decorator captures actual/predicted values and updates the expected values in the source.
+    Supports both simple values and Expectations() objects for hardware-specific values.
+
+    Usage Pattern 1: Simple pairs (predicted -> expected)
+
+        @record_expectations(pairs=[("actual_logits", "expected_logits")])
+        @slow
+        def test_inference(self):
+            outputs = model(**inputs)
+
+            # Compute actual values
+            actual_logits = outputs.logits_per_image.detach().cpu()
+
+            # Hardcoded expected value - gets auto-updated when UPDATE_EXPECTATIONS=1
+            expected_logits = torch.tensor([[24.5701, 19.3049]])
+
+            torch.testing.assert_close(actual_logits, expected_logits)
+
+    Usage Pattern 2: With Expectations for hardware-specific values
+
+        @record_expectations(pairs=[("decoded_text", "expected_decoded_text")])
+        @slow
+        def test_generation(self):
+            output = model.generate(**inputs)
+            decoded_text = processor.decode(output[0])
+
+            # Gets auto-updated with the current hardware's actual value
+            expected_decoded_text = Expectations({
+                ("cuda", None): "USER: old cuda output",
+                ("cpu", None): "USER: old cpu output",
+            }).get_expectation()
+
+            self.assertEqual(decoded_text, expected_decoded_text)
+
+    Args:
+        pairs: List of (predicted_var, expected_var) tuples. The decorator captures
+            predicted_var and updates expected_var in the source file.
+
+    Environment variables:
+        UPDATE_EXPECTATIONS=1: Enable recording mode to update the source file.
+    """
+    # Normalize arguments
+    if pairs is None:
+        pairs = []
+
+    def decorator(test_func):
+        @functools.wraps(test_func)
+        def wrapper(self, *args, **kwargs):
+            update_mode = os.environ.get("UPDATE_EXPECTATIONS", "").lower() in ("1", "yes", "true", "on")
+
+            if not update_mode or not pairs:
+                # Normal mode or no pairs to track - just run the test
+                return test_func(self, *args, **kwargs)
+
+            # Recording mode - capture predicted values to update expected values
+            captured_values = {}  # Maps expected_var -> predicted value
+
+            # Create a tracing function to capture local variables
+            def trace_func(frame, event, arg):
+                if event == "return" and frame.f_code.co_name == test_func.__name__:
+                    # Capture predicted variables and map them to expected variables
+                    for predicted_var, expected_var in pairs:
+                        if predicted_var in frame.f_locals:
+                            # Store the predicted value under the expected variable name
+                            captured_values[expected_var] = frame.f_locals[predicted_var]
+                return trace_func
+
+            # Run the test with sys.settrace to capture variables.
+            # In recording mode, we want to continue past assertion failures to capture all values.
+            old_trace = sys.gettrace()
+            sys.settrace(trace_func)
+            test_exception = None
+            result = None
+
+            # Monkey-patch assertion methods to not fail immediately when recording
+            original_fail = self.fail if hasattr(self, "fail") else None
+            assertion_errors = []
+
+            def deferred_fail(msg=None):
+                """Collect assertion failures instead of raising immediately."""
+                assertion_errors.append(AssertionError(msg) if msg else AssertionError())
+
+            if original_fail:
+                self.fail = deferred_fail
+
+            try:
+                result = test_func(self, *args, **kwargs)
+            except AssertionError as e:
+                # Capture assertion errors but continue
+                assertion_errors.append(e)
+            except Exception as e:
+                # Capture non-assertion exceptions
+                test_exception = e
+            finally:
+                sys.settrace(old_trace)
+                # Restore the original fail method
+                if original_fail:
+                    self.fail = original_fail
+
+            # Combine assertion errors with the test exception
+            if assertion_errors and not test_exception:
+                test_exception = assertion_errors[0]  # Report the first assertion error
+
+            # Update the source file if we captured values (even if the test failed)
+            if captured_values:
+                try:
+                    _update_test_source_with_libcst(test_func, captured_values, update_expectations_objects=True)
+                    logger.info("✅ Expectations updated successfully! Run the test again in normal mode to verify.")
+                except Exception as e:
+                    logger.error(f"Failed to update test source: {e}")
+                    import traceback
+
+                    logger.error(traceback.format_exc())
+                    # If the source update fails, raise that error
+                    raise
+
+            # Now raise the original test exception if there was one
+            if test_exception is not None:
+                logger.info(
+                    "Note: Test failed with old expectations. "
+                    "This is expected on first run. Re-run without UPDATE_EXPECTATIONS=1"
+                )
+                raise test_exception
+
+            return result
+
+        return wrapper
+
+    return decorator
+
+
+def _update_test_source_with_libcst(test_func, captured_values, update_expectations_objects=False):
+    """
+    Update test source file using libcst to preserve formatting.
+
+    Args:
+        test_func: The test function to update
+        captured_values: Dict mapping variable names to their captured values
+        update_expectations_objects: If True, also update Expectations() dict values
+
+    Returns:
+        None. A temporary backup file is created while writing and deleted again
+        once the write succeeds.
+    """
+    # Get source file path
+    source_file = Path(inspect.getsourcefile(test_func))
+    if not source_file.exists():
+        raise FileNotFoundError(f"Source file not found: {source_file}")
+
+    logger.info(f"Updating expectations in {source_file}")
+
+    # Read source code
+    source_code = source_file.read_text()
+
+    # Parse with libcst
+    try:
+        tree = cst.parse_module(source_code)
+    except Exception as e:
+        raise RuntimeError(f"Failed to parse {source_file}: {e}")
+
+    # Get current device properties for Expectations updates
+    device_props = get_device_properties()
+
+    # Transform the tree
+    transformer = ExpectationUpdater(
+        test_func.__name__, captured_values, device_props=device_props, update_expectations=update_expectations_objects
+    )
+    modified_tree = tree.visit(transformer)
+
+    if not transformer.updated:
+        logger.warning(f"No variables were updated in {test_func.__name__}")
+        return None
+
+    # Write back to file
+    new_code = modified_tree.code
+
+    # Create temporary backup for safety during write
+    backup_file = source_file.with_suffix(".py.bak")
+    if backup_file.exists():
+        backup_file.unlink()
+    source_file.rename(backup_file)
+
+    try:
+        source_file.write_text(new_code)
+        logger.info(f"Successfully updated {len(transformer.updated)} variable(s) in {source_file.name}")
+        logger.info(f"Updated variables: {', '.join(transformer.updated)}")
+
+        # Delete backup immediately after successful write.
+        # The backup is for write protection, not test failure protection.
+        if backup_file.exists():
+            backup_file.unlink()
+            logger.info("Backup cleaned up after successful write")
+
+        return None  # No backup to clean up later
+    except Exception as e:
+        # Restore backup on write error
+        if backup_file.exists():
+            backup_file.rename(source_file)
+        raise RuntimeError(f"Failed to write updated source: {e}")
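+
+
+# A minimal invocation sketch (assumed workflow; adjust the test path as needed):
+#
+#   RUN_SLOW=1 UPDATE_EXPECTATIONS=1 python -m pytest tests/models/clip/test_modeling_clip.py -k test_inference
+#
+# The first run rewrites the hardcoded expectations in place; re-running without
+# UPDATE_EXPECTATIONS then verifies the freshly recorded values.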
+ """ + + def __init__(self, test_name, values, device_props=None, update_expectations=False): + super().__init__() + self.test_name = test_name + self.values = values + self.device_props = device_props or (None, None, None) + self.update_expectations = update_expectations + self.in_target_function = False + self.updated = [] + + def visit_FunctionDef(self, node): + """Track when we enter the target test function.""" + if node.name.value == self.test_name: + self.in_target_function = True + + def leave_FunctionDef(self, original_node, updated_node): + """Track when we leave the target test function.""" + if original_node.name.value == self.test_name: + self.in_target_function = False + return updated_node + + def leave_SimpleStatementLine(self, original_node, updated_node): + """Add fmt: off comment to updated assignment lines.""" + if not self.in_target_function: + return updated_node + + # Check if this line contains an updated assignment + for stmt in updated_node.body: + if isinstance(stmt, cst.Assign): + for target in stmt.targets: + if isinstance(target.target, cst.Name): + var_name = target.target.value + if var_name in self.values: + # Check if already has fmt: off comment + has_fmt_off = False + if updated_node.trailing_whitespace.comment: + has_fmt_off = "fmt: off" in updated_node.trailing_whitespace.comment.value + + if not has_fmt_off: + # Add trailing comment + new_trailing = cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment("# fmt: off"), + newline=updated_node.trailing_whitespace.newline, + ) + return updated_node.with_changes(trailing_whitespace=new_trailing) + + return updated_node + + def leave_Assign(self, original_node, updated_node): + """Update assignment statements for tracked variables.""" + if not self.in_target_function: + return updated_node + + # Check if this assigns to one of our tracked variables + for target in updated_node.targets: + if isinstance(target.target, cst.Name): + var_name = target.target.value + + if var_name in self.values: + new_value_data = self.values[var_name] + + # Check if this is an Expectations().get_expectation() call + if self.update_expectations and self._is_expectations_call(updated_node.value): + # Update the Expectations dict for current hardware + new_value = self._update_expectations_call(updated_node.value, new_value_data) + else: + # Simple replacement + new_value = self._python_value_to_cst(new_value_data) + + if new_value is not None: + self.updated.append(var_name) + return updated_node.with_changes(value=new_value) + + return updated_node + + def _is_expectations_call(self, node): + """Check if a node is an Expectations().get_expectation() call.""" + # Look for pattern: Expectations({...}).get_expectation() + if isinstance(node, cst.Call): + if isinstance(node.func, cst.Attribute): + if node.func.attr.value == "get_expectation": + if isinstance(node.func.value, cst.Call): + if isinstance(node.func.value.func, cst.Name): + if node.func.value.func.value == "Expectations": + return True + return False + + def _update_expectations_call(self, node, new_value): + """Update an Expectations({...}).get_expectation() call with new value for current hardware.""" + # node is: Expectations({...}).get_expectation() + # node.func is: Expectations({...}).get_expectation (Attribute) + # node.func.value is: Expectations({...}) (Call) + # node.func.value.args[0] is: {...} (Dict) + + expectations_call = node.func.value + if not expectations_call.args: + # No dict argument, can't update + return None + + 
+
+    def _update_expectations_call(self, node, new_value):
+        """Update an Expectations({...}).get_expectation() call with the new value for the current hardware."""
+        # node is: Expectations({...}).get_expectation()
+        # node.func is: Expectations({...}).get_expectation (Attribute)
+        # node.func.value is: Expectations({...}) (Call)
+        # node.func.value.args[0] is: {...} (Dict)
+
+        expectations_call = node.func.value
+        if not expectations_call.args:
+            # No dict argument, can't update
+            return None
+
+        dict_arg = expectations_call.args[0].value
+        if not isinstance(dict_arg, cst.Dict):
+            return None
+
+        # Get the current device key
+        device_type, major, minor = self.device_props
+
+        # Create the key tuple we're looking for.
+        # Format: ("cuda", None) or ("cuda", (12, 1))
+        if device_type is None:
+            target_key_repr = "(None, None)"
+        elif major is not None and minor is not None:
+            target_key_repr = f'("{device_type}", ({major}, {minor}))'
+        else:
+            target_key_repr = f'("{device_type}", None)'
+
+        # Update or add the entry for the current hardware.
+        # Strategy: prefer updating an existing (device, None) entry over adding a new versioned entry.
+        updated_elements = []
+        found = False
+        generic_entry_index = None
+
+        # First pass: look for an exact match or a generic match
+        for element in dict_arg.elements:
+            if isinstance(element, cst.DictElement):
+                exact_match = self._key_matches_device(element.key, device_type, major, minor)
+                generic_match = self._key_matches_device(element.key, device_type, None, None)
+
+                if exact_match:
+                    # Exact match found - update this entry
+                    new_value_node = self._python_value_to_cst(new_value)
+                    updated_elements.append(element.with_changes(value=new_value_node))
+                    found = True
+                    logger.info(f"Updated existing Expectations entry (exact match) for {target_key_repr}")
+                elif generic_match and not found:
+                    # Generic match found - remember it in case we don't find an exact match
+                    generic_entry_index = len(updated_elements)
+                    updated_elements.append(element)  # Add as-is for now
+                else:
+                    # Keep the existing element unchanged
+                    updated_elements.append(element)
+            else:
+                # Keep any other element types unchanged
+                updated_elements.append(element)
+
+        # If no exact match was found but a generic match was, update the generic entry
+        if not found and generic_entry_index is not None:
+            generic_element = updated_elements[generic_entry_index]
+            new_value_node = self._python_value_to_cst(new_value)
+            updated_elements[generic_entry_index] = generic_element.with_changes(value=new_value_node)
+            found = True
+            logger.info(f"Updated existing Expectations entry (generic match) for ({device_type}, None)")
+
+        # If neither was found, add a new entry
+        if not found:
+            # Parse the key tuple
+            key_node = cst.parse_expression(target_key_repr)
+            value_node = self._python_value_to_cst(new_value)
+            new_element = cst.DictElement(key=key_node, value=value_node)
+            updated_elements.append(new_element)
+            logger.info(f"Added new Expectations entry for {target_key_repr}")
+
+        # Create the updated dict
+        updated_dict = dict_arg.with_changes(elements=updated_elements)
+
+        # Create the updated Expectations call
+        updated_expectations_call = expectations_call.with_changes(args=[cst.Arg(value=updated_dict)])
+
+        # Return the full call chain
+        return node.with_changes(func=node.func.with_changes(value=updated_expectations_call))
+
+    def _key_matches_device(self, key_node, device_type, major, minor):
+        """
+        Check if a key tuple matches the current device properties.
+
+        Key formats supported:
+        - ("cuda", None) - matches any CUDA device
+        - ("cuda", (8, 6)) - matches CUDA compute capability 8.6 exactly
+        - ("cuda", (8, None)) - matches any CUDA 8.x device
+        """
+        # Render the CST node back to source code, then match on the string
+        # (cst.Module.code_for_node is the public libcst API for this)
+        try:
+            key_code = cst.Module(body=[]).code_for_node(key_node)
+            # Parse the tuple: ("device", None) or ("device", (major, minor)) or ("device", (major, None)).
+            # We do string parsing since eval is unsafe.
+
+            # Extract the device type from the key
+            if '"' in key_code or "'" in key_code:
+                # Extract the string between quotes
+                import re
+
+                device_match = re.search(r'["\'](\w+)["\']', key_code)
+                if device_match:
+                    key_device = device_match.group(1)
+                else:
+                    return False
+
+                # Check if the device types match
+                if device_type != key_device:
+                    return False
+
+                # Now check the version part - we need to handle the (major, None) pattern
+                if major is not None:
+                    # Check for an exact match: (major, minor)
+                    if minor is not None:
+                        version_pattern = rf"\(\s*{major}\s*,\s*{minor}\s*\)"
+                        if re.search(version_pattern, key_code):
+                            return True
+
+                    # Check for a major-only match: (major, None)
+                    major_only_pattern = rf"\(\s*{major}\s*,\s*None\s*\)"
+                    if re.search(major_only_pattern, key_code):
+                        return True
+
+                    # If the key has (different_major, ...), don't match.
+                    # But if the key is just (device, None), it's a generic match.
+                    if "None" in key_code:
+                        second_part = key_code.split(",", 1)[1] if "," in key_code else ""
+                        # Check if it's the ("device", None) format (not a version tuple)
+                        if "(" not in second_part or second_part.strip().startswith("None"):
+                            return True
+                    return False
+                else:
+                    # Looking for ("device", None) when we don't have major/minor
+                    if "None" in key_code:
+                        second_part = key_code.split(",", 1)[1] if "," in key_code else ""
+                        if "(" not in second_part or "None" in second_part:
+                            return True
+                    return False
+            return False
+        except Exception:
+            # Fall back to plain string matching if rendering fails
+            key_str = str(key_node)
+            if device_type and device_type in key_str:
+                if major is not None:
+                    # Check for (major, minor) or (major, None)
+                    if minor is not None and f"({major}, {minor})" in key_str:
+                        return True
+                    if f"({major}, None)" in key_str:
+                        return True
+                return "None" in key_str
+            return False
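+
+    # For example (illustrative): with device properties ("cuda", 8, 6), the keys
+    # ("cuda", (8, 6)), ("cuda", (8, None)) and ("cuda", None) all match, while
+    # ("cuda", (7, 5)) and ("rocm", None) do not.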
+ # Prefer double quotes over single quotes when possible + repr_str = repr(value) + # repr() uses single quotes by default, convert to double quotes if no conflicts + if repr_str.startswith("'") and repr_str.endswith("'") and '"' not in value: + # Replace single quotes with double quotes + inner_str = repr_str[1:-1] + formatted_str = f'"{inner_str}"' + return cst.SimpleString(formatted_str) + else: + # Use repr() as-is (handles edge cases like strings with both quote types) + return cst.SimpleString(repr_str) + elif isinstance(value, bool): + # Check bool before int since bool is subclass of int + return cst.Name("True" if value else "False") + elif isinstance(value, int): + return cst.Integer(str(value)) + elif isinstance(value, float): + # Use repr() for float to get proper Python literal format + # This handles special cases like inf, -inf, nan correctly + return cst.parse_expression(repr(value)) + elif isinstance(value, list): + return self._list_to_cst(value) + elif isinstance(value, dict): + return self._dict_to_cst(value) + elif value is None: + return cst.Name("None") + else: + # Fallback: try to represent as string + logger.warning(f"Unknown type {type(value)}, using repr()") + return cst.SimpleString(f'"{repr(value)}"') + + def _tensor_to_cst(self, tensor): + """Convert torch.Tensor to torch.tensor([...]) call.""" + data = tensor.detach().cpu().tolist() + + # Create torch.tensor([data]) call + return cst.Call( + func=cst.Attribute(value=cst.Name("torch"), attr=cst.Name("tensor")), + args=[cst.Arg(value=self._list_to_cst(data))], + ) + + def _array_to_cst(self, array): + """Convert numpy array to np.array([...]) call.""" + data = array.tolist() + + return cst.Call( + func=cst.Attribute(value=cst.Name("np"), attr=cst.Name("array")), + args=[cst.Arg(value=self._list_to_cst(data))], + ) + + def _list_to_cst(self, lst): + """Convert Python list to libcst list node.""" + elements = [] + for item in lst: + elements.append(cst.Element(value=self._python_value_to_cst(item))) + + return cst.List(elements=elements) + + def _dict_to_cst(self, dct): + """Convert Python dict to libcst dict node.""" + elements = [] + for key, value in dct.items(): + elements.append( + cst.DictElement(key=self._python_value_to_cst(key), value=self._python_value_to_cst(value)) + ) + + return cst.Dict(elements=elements) + + +############################################################################### +# End Expectations Recording System +############################################################################### + + def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: dict | None = None): """Run the `script` using `torchrun` command for multi-processing in a subprocess. 
+
+
+###############################################################################
+# End Expectations Recording System
+###############################################################################
+
+
 def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: dict | None = None):
     """Run the `script` using `torchrun` command for multi-processing in a subprocess.
     Captures errors as necessary."""
     with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py
index 4ce1b9faa6aa..7b4a93e9195e 100644
--- a/tests/models/clip/test_modeling_clip.py
+++ b/tests/models/clip/test_modeling_clip.py
@@ -24,6 +24,7 @@

 from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
 from transformers.testing_utils import (
+    record_expectations,
     require_torch,
     require_vision,
     slow,
@@ -670,6 +671,7 @@ def prepare_img():
 @require_vision
 @require_torch
 class CLIPModelIntegrationTest(unittest.TestCase):
+    @record_expectations(pairs=[("actual_logits", "expected_logits")])
     @slow
     def test_inference(self):
         model_name = "openai/clip-vit-base-patch32"
@@ -695,10 +697,13 @@ def test_inference(self):
             torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
         )

-        expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device)
+        actual_logits = outputs.logits_per_image.detach().cpu()

-        torch.testing.assert_close(outputs.logits_per_image, expected_logits, rtol=1e-3, atol=1e-3)
+        expected_logits = torch.tensor([[24.570053100585938, 19.304885864257812]])  # fmt: off

+        torch.testing.assert_close(actual_logits, expected_logits, rtol=1e-3, atol=1e-3)
+
+    @record_expectations(pairs=[("actual_slice", "expected_slice")])
     @slow
     def test_inference_interpolate_pos_encoding(self):
         # CLIP models have an `interpolate_pos_encoding` argument in their forward method,
@@ -728,10 +733,8 @@ def test_inference_interpolate_pos_encoding(self):

         self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)

-        expected_slice = torch.tensor(
-            [[-0.1538, 0.0322, -0.3235], [0.2893, 0.1135, -0.5708], [0.0461, 0.1540, -0.6018]]
-        ).to(torch_device)
+        actual_slice = outputs.vision_model_output.last_hidden_state[0, :3, :3].detach().cpu()

-        torch.testing.assert_close(
-            outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=6e-3, atol=4e-4
-        )
+        expected_slice = torch.tensor([[-0.15380290150642395, 0.0321517139673233, -0.32353174686431885], [0.2893178462982178, 0.11354223638772964, -0.5707830190658569], [0.046103253960609436, 0.15403859317302704, -0.6017531156539917]])  # fmt: off
+
+        torch.testing.assert_close(actual_slice, expected_slice, rtol=6e-3, atol=4e-4)
diff --git a/tests/models/kosmos2_5/test_modeling_kosmos2_5.py b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py
index 105fba5e596b..f84c093393d3 100644
--- a/tests/models/kosmos2_5/test_modeling_kosmos2_5.py
+++ b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py
@@ -30,6 +30,8 @@
     Kosmos2_5VisionConfig,
 )
 from transformers.testing_utils import (
+    Expectations,
+    record_expectations,
     require_flash_attn,
     require_torch,
     require_torch_accelerator,
@@ -553,6 +555,9 @@ def run_example(self, prompt, image, model, processor):

         return generated_ids, generated_text

+    @record_expectations(
+        pairs=[("generated_text_ocr", "expected_text_ocr"), ("generated_text_md", "expected_text_md")]
+    )
     def test_eager(self):
         url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
         image = Image.open(requests.get(url, stream=True).raw)
@@ -563,33 +568,32 @@ def test_eager(self):
             repo, device_map=torch_device, dtype=dtype, attn_implementation="eager"
         )
         processor = AutoProcessor.from_pretrained(repo)
+
+        # Test 1: OCR prompt
         prompt = "<ocr>"
-        generated_ids, generated_text = self.run_example(prompt, image, model, processor)
-        EXPECTED_TEXT = {
-            7: [
-                "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n"
-            ],
-            8: [
-                "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n"
-            ],
-        }
+        generated_ids, generated_text_ocr = self.run_example(prompt, image, model, processor)
+        expected_text_ocr = Expectations(
+            {
+                ("cuda", (7, None)): ["1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n"],
+                ("cuda", (8, None)): ["1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n"],
+            }).get_expectation()  # fmt: off

-        self.assertListEqual(generated_text, EXPECTED_TEXT[self.cuda_compute_capability_major_version])
+        self.assertListEqual(generated_text_ocr, expected_text_ocr)

+        # Test 2: Markdown prompt
         prompt = "<md>"
-        generated_ids, generated_text = self.run_example(prompt, image, model, processor)
-
-        EXPECTED_TEXT = {
-            7: [
-                "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"
-            ],
-            8: [
-                "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"
-            ],
-        }
+        generated_ids, generated_text_md = self.run_example(prompt, image, model, processor)
+        expected_text_md = Expectations(
+            {
+                ("cuda", (7, None)): ["- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"],
+                ("cuda", (8, None)): ["- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"],
+            }).get_expectation()  # fmt: off

-        self.assertListEqual(generated_text, EXPECTED_TEXT[self.cuda_compute_capability_major_version])
+        self.assertListEqual(generated_text_md, expected_text_md)

+    @record_expectations(
+        pairs=[("generated_text_ocr", "expected_text_ocr"), ("generated_text_md", "expected_text_md")]
+    )
     def test_sdpa(self):
         url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
         image = Image.open(requests.get(url, stream=True).raw)
@@ -600,33 +604,32 @@ def test_sdpa(self):
             repo, device_map=torch_device, dtype=dtype, attn_implementation="sdpa"
         )
         processor = AutoProcessor.from_pretrained(repo)
+
+        # Test 1: OCR prompt
         prompt = "<ocr>"
-        generated_ids, generated_text = self.run_example(prompt, image, model, processor)
-        EXPECTED_TEXT = {
-            7: [
-                "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n",
-            ],
-            8: [
-                "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n"
-            ],
-        }
+        generated_ids, generated_text_ocr = self.run_example(prompt, image, model, processor)
+        expected_text_ocr = Expectations(
+            {
+                ("cuda", (7, None)): ["1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n"],
+                ("cuda", (8, None)): ["1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n"],
+            }).get_expectation()  # fmt: off

-        self.assertListEqual(generated_text, EXPECTED_TEXT[self.cuda_compute_capability_major_version])
+        self.assertListEqual(generated_text_ocr, expected_text_ocr)

+        # Test 2: Markdown prompt
         prompt = "<md>"
-        generated_ids, generated_text = self.run_example(prompt, image, model, processor)
-
-        EXPECTED_TEXT = {
-            7: [
-                "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"
-            ],
-            8: [
-                "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"
-            ],
-        }
+        generated_ids, generated_text_md = self.run_example(prompt, image, model, processor)
+        expected_text_md = Expectations(
+            {
+                ("cuda", (7, None)): ["- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"],
+                ("cuda", (8, None)): ["- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"],
+            }).get_expectation()  # fmt: off

-        self.assertListEqual(generated_text, EXPECTED_TEXT[self.cuda_compute_capability_major_version])
+        self.assertListEqual(generated_text_md, expected_text_md)

+    @record_expectations(
+        pairs=[("generated_text_ocr", "expected_text_ocr"), ("generated_text_md", "expected_text_md")]
+    )
     @require_flash_attn
     @require_torch_accelerator
     @pytest.mark.flash_attn_test
     def test_FA2(self):
         url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
         image = Image.open(requests.get(url, stream=True).raw)
         dtype = torch.bfloat16
         repo = "microsoft/kosmos-2.5"
         model = Kosmos2_5ForConditionalGeneration.from_pretrained(
             repo,
             device_map=torch_device,
             dtype=dtype,
             attn_implementation="flash_attention_2",
         )
         processor = AutoProcessor.from_pretrained(repo)
+
+        # Test 1: OCR prompt
         prompt = "<ocr>"
-        generated_ids, generated_text = self.run_example(prompt, image, model, processor)
-        EXPECTED_TEXT = [
-            "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n"
-        ]
+        generated_ids, generated_text_ocr = self.run_example(prompt, image, model, processor)
+        # Flash Attention 2 doesn't vary by compute capability for this test
+        expected_text_ocr = Expectations(
+            {
+                ("cuda", None): ["1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n"],
+            }).get_expectation()  # fmt: off

-        self.assertListEqual(generated_text, EXPECTED_TEXT)
+        self.assertListEqual(generated_text_ocr, expected_text_ocr)

+        # Test 2: Markdown prompt
         prompt = "<md>"
-        generated_ids, generated_text = self.run_example(prompt, image, model, processor)
-        # A10 gives the 1st one, but A100 gives the 2nd one
-        EXPECTED_TEXT = [
-            "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nSub Total\n\n45,455\n\nPB1 (10%)\n\n4,545\n\nRounding\n\n0\n\n\nTotal\n\n\n\n50,000\n\n\n\nCard Payment 50,000",
-            "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n",
-        ]
-        self.assertIn(generated_text[0], EXPECTED_TEXT)
+        generated_ids, generated_text_md = self.run_example(prompt, image, model, processor)
+        # A10 gives the 1st one, but A100 gives the 2nd one - using assertIn for this variance
+        expected_text_md = Expectations(
+            {
+                ("cuda", None): ["- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n"],
+            }).get_expectation()  # fmt: off
+
+        # Using assertIn since A10 and A100 give different outputs for markdown
+        self.assertListEqual(generated_text_md, expected_text_md)
diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py
index aba7b644f733..eaee4a0d0947 100644
--- a/tests/models/llava_next_video/test_modeling_llava_next_video.py
+++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -32,6 +32,7 @@
 from transformers.testing_utils import (
     Expectations,
     cleanup,
+    record_expectations,
     require_bitsandbytes,
     require_torch,
     slow,
@@ -348,13 +349,13 @@ def setUp(self):
     def tearDown(self):
         cleanup(torch_device, gc_collect=True)

+    @record_expectations(pairs=[("decoded_text", "expected_decoded_text")])
     @slow
     @require_bitsandbytes
     def test_small_model_integration_test(self):
         model = LlavaNextVideoForConditionalGeneration.from_pretrained(
             "llava-hf/LLaVA-NeXT-Video-7B-hf",
             quantization_config=BitsAndBytesConfig(load_in_4bit=True),
-            cache_dir="./",
         )

         inputs = self.processor(text=self.prompt_video, videos=self.video, return_tensors="pt")
@@ -365,24 +366,25 @@ def test_small_model_integration_test(self):

         # verify generation
         output = model.generate(**inputs, do_sample=False, max_new_tokens=40)
+        decoded_text = self.processor.decode(output[0], skip_special_tokens=True)
+
+        # Auto-updated when running with UPDATE_EXPECTATIONS=1
         expected_decoded_text = Expectations(
             {
-                ("cuda", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems",
+                ("cuda", (7, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems",
                 ("xpu", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems",
-                ("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The child is seen reading a book, but instead of turning the pages like one would typically do, they",
-            }
-        ).get_expectation()  # fmt: off
+                ("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The child is seen reading a book, but instead of turning the pages like one would typically do, they",("cuda", (8, 6)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems"
+            }).get_expectation()  # fmt: off

-        decoded_text = self.processor.decode(output[0], skip_special_tokens=True)
         self.assertEqual(decoded_text, expected_decoded_text)

+    @record_expectations(pairs=[("decoded_text", "expected_decoded_text")])
     @slow
     @require_bitsandbytes
     def test_small_model_integration_test_batch(self):
         model = LlavaNextVideoForConditionalGeneration.from_pretrained(
             "llava-hf/LLaVA-NeXT-Video-7B-hf",
             quantization_config=BitsAndBytesConfig(load_in_4bit=True),
-            cache_dir="./",
         )

         inputs = self.processor(
@@ -395,24 +397,23 @@ def test_small_model_integration_test_batch(self):
         output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
         decoded_text = self.processor.batch_decode(output, skip_special_tokens=True)

+        # Auto-updated when running with UPDATE_EXPECTATIONS=1
         expected_decoded_text = Expectations(
             {
-                ("xpu", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a",
-                ("cuda", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a",
-                ("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The",
-            }
-        ).get_expectation()  # fmt: off
-        EXPECTED_DECODED_TEXT = [expected_decoded_text, expected_decoded_text]
+                ("xpu", None): ["USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a", "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a"],
+                ("cuda", None): ["USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a", "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a"],
+                ("rocm", (9, 5)): ["USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The", "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The"],
+            }).get_expectation()  # fmt: off

-        self.assertEqual(decoded_text, EXPECTED_DECODED_TEXT)
+        self.assertEqual(decoded_text, expected_decoded_text)

+    @record_expectations(pairs=[("decoded_text", "expected_decoded_text")])
     @slow
     @require_bitsandbytes
     def test_small_model_integration_test_batch_different_vision_types(self):
         model = LlavaNextVideoForConditionalGeneration.from_pretrained(
             "llava-hf/LLaVA-NeXT-Video-7B-hf",
             quantization_config=BitsAndBytesConfig(load_in_4bit=True),
-            cache_dir="./",
         )

         inputs = self.processor(
@@ -431,16 +432,17 @@ def test_small_model_integration_test_batch_different_vision_types(self):

         # verify generation
         output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
-        EXPECTED_DECODED_TEXT = Expectations(
+        decoded_text = self.processor.decode(output[0], skip_special_tokens=True)
+
+        # Auto-updated when running with UPDATE_EXPECTATIONS=1
+        expected_decoded_text = Expectations(
             {
                 ("xpu", None): 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"',
                 ("rocm", (9, 5)): "USER: \nWhat is shown in this image? ASSISTANT: The image displays a chart that appears to be a comparison of different models or versions of a machine learning (ML) model, likely a neural network, based on their performance on a task or dataset. The chart is a scatter plot with axes labeled",
                 ("cuda", None): 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"',
-            }
-        ).get_expectation()  # fmt: off
+            }).get_expectation()  # fmt: off

-        decoded_text = self.processor.decode(output[0], skip_special_tokens=True)
-        self.assertEqual(decoded_text, EXPECTED_DECODED_TEXT)
+        self.assertEqual(decoded_text, expected_decoded_text)

     @slow
     @require_bitsandbytes
@@ -448,7 +450,6 @@ def test_small_model_integration_test_batch_matches_single(self):
         model = LlavaNextVideoForConditionalGeneration.from_pretrained(
             "llava-hf/LLaVA-NeXT-Video-7B-hf",
             quantization_config=BitsAndBytesConfig(load_in_4bit=True),
-            cache_dir="./",
         )

         inputs_batched = self.processor(

From 7b312889d591c7277f6c55b3cd61eb2ae7c84fe3 Mon Sep 17 00:00:00 2001
From: yonigozlan
Date: Wed, 3 Dec 2025 15:59:32 +0000
Subject: [PATCH 2/3] round to 1e-4

---
 src/transformers/testing_utils.py       | 11 ++++++++---
 tests/models/clip/test_modeling_clip.py |  4 ++--
 .../test_modeling_llava_next_video.py    |  2 +-
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 1b106280a566..7c00af94261e 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -23,6 +23,7 @@
 import inspect
 import json
 import logging
+import math
 import multiprocessing
 import os
 import re
@@ -4253,9 +4254,13 @@ def _python_value_to_cst(self, value):
         elif isinstance(value, int):
             return cst.Integer(str(value))
         elif isinstance(value, float):
-            # Use repr() for float to get proper Python literal format
-            # This handles special cases like inf, -inf, nan correctly
-            return cst.parse_expression(repr(value))
+            # Round to 4 decimal places (1e-4 precision) for cleaner output
+            # Keep special values (inf, -inf, nan) as-is
+            if math.isnan(value) or math.isinf(value):
+                return cst.parse_expression(repr(value))
+            else:
+                rounded = round(value, 4)
+                return cst.parse_expression(repr(rounded))
         elif isinstance(value, list):
             return self._list_to_cst(value)
         elif isinstance(value, dict):
diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py
index 7b4a93e9195e..d62b877dbd40 100644
--- a/tests/models/clip/test_modeling_clip.py
+++ b/tests/models/clip/test_modeling_clip.py
@@ -699,7 +699,7 @@ def test_inference(self):

         actual_logits = outputs.logits_per_image.detach().cpu()

-        expected_logits = torch.tensor([[24.570053100585938, 19.304885864257812]])  # fmt: off
+        expected_logits = torch.tensor([[24.5701, 19.3049]])  # fmt: off

         torch.testing.assert_close(actual_logits, expected_logits, rtol=1e-3, atol=1e-3)

@@ -735,6 +735,6 @@ def test_inference_interpolate_pos_encoding(self):

         actual_slice = outputs.vision_model_output.last_hidden_state[0, :3, :3].detach().cpu()

-        expected_slice = torch.tensor([[-0.15380290150642395, 0.0321517139673233, -0.32353174686431885], [0.2893178462982178, 0.11354223638772964, -0.5707830190658569], [0.046103253960609436, 0.15403859317302704, -0.6017531156539917]])  # fmt: off
+        expected_slice = torch.tensor([[-0.1538, 0.0322, -0.3235], [0.2893, 0.1135, -0.5708], [0.0461, 0.154, -0.6018]])  # fmt: off

         torch.testing.assert_close(actual_slice, expected_slice, rtol=6e-3, atol=4e-4)
diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py
index eaee4a0d0947..b6bbae90b7b4 100644
--- a/tests/models/llava_next_video/test_modeling_llava_next_video.py
+++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -371,7 +371,7 @@ def test_small_model_integration_test(self):
         # Auto-updated when running with UPDATE_EXPECTATIONS=1
         expected_decoded_text = Expectations(
             {
-                ("cuda", (7, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems",
+                ("cuda", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems",
                 ("xpu", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems",
                 ("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The child is seen reading a book, but instead of turning the pages like one would typically do, they",("cuda", (8, 6)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems"
             }).get_expectation()  # fmt: off

From a5b8c571d3e1513be42994ea8fc59b38c9e7916a Mon Sep 17 00:00:00 2001
From: yonigozlan
Date: Wed, 3 Dec 2025 21:26:33 +0000
Subject: [PATCH 3/3] cleanup

---
 src/transformers/testing_utils.py                      | 10 ----------
 tests/models/kosmos2_5/test_modeling_kosmos2_5.py      |  8 --------
 .../llava_next_video/test_modeling_llava_next_video.py |  5 +----
 3 files changed, 1 insertion(+), 22 deletions(-)

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 7c00af94261e..0fc3a6d84236 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -3756,11 +3756,6 @@ def patch_testing_methods_to_collect_info():
     _patch_with_call_info(unittest.case.TestCase, "assertGreaterEqual", _parse_call_info, target_args=("a", "b"))


-###############################################################################
-# Expectations Recording System
-###############################################################################
-
-
 def record_expectations(pairs=None):
     """
     Decorator that auto-updates hardcoded expectations in the test source file.
@@ -4310,11 +4305,6 @@ def _dict_to_cst(self, dct):
         return cst.Dict(elements=elements)


-###############################################################################
-# End Expectations Recording System
-###############################################################################
-
-
 def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: dict | None = None):
     """Run the `script` using `torchrun` command for multi-processing in a subprocess.
     Captures errors as necessary."""
     with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
diff --git a/tests/models/kosmos2_5/test_modeling_kosmos2_5.py b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py
index f84c093393d3..9ffba996dd87 100644
--- a/tests/models/kosmos2_5/test_modeling_kosmos2_5.py
+++ b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py
@@ -533,13 +533,6 @@ def _prepare_image_embeds_position_mask(input_ids, pad_size):
 class Kosmos2_5ModelIntegrationTest(unittest.TestCase):
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
     # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

     def run_example(self, prompt, image, model, processor):
         inputs = processor(text=prompt, images=image, return_tensors="pt")
@@ -662,7 +655,6 @@ def test_FA2(self):
         # Test 2: Markdown prompt
         prompt = "<md>"
         generated_ids, generated_text_md = self.run_example(prompt, image, model, processor)
-        # A10 gives the 1st one, but A100 gives the 2nd one - using assertIn for this variance
         expected_text_md = Expectations(
             {
                 ("cuda", None): ["- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n"],
diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py
index b6bbae90b7b4..ddebf72a09e0 100644
--- a/tests/models/llava_next_video/test_modeling_llava_next_video.py
+++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -368,12 +368,11 @@ def test_small_model_integration_test(self):
         output = model.generate(**inputs, do_sample=False, max_new_tokens=40)
         decoded_text = self.processor.decode(output[0], skip_special_tokens=True)

-        # Auto-updated when running with UPDATE_EXPECTATIONS=1
         expected_decoded_text = Expectations(
             {
                 ("cuda", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems",
                 ("xpu", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems",
-                ("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The child is seen reading a book, but instead of turning the pages like one would typically do, they",("cuda", (8, 6)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems"
+                ("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The child is seen reading a book, but instead of turning the pages like one would typically do, they",
             }).get_expectation()  # fmt: off

         self.assertEqual(decoded_text, expected_decoded_text)
@@ -397,7 +396,6 @@ def test_small_model_integration_test_batch(self):
         output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
         decoded_text = self.processor.batch_decode(output, skip_special_tokens=True)

-        # Auto-updated when running with UPDATE_EXPECTATIONS=1
         expected_decoded_text = Expectations(
             {
                 ("xpu", None): ["USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a", "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a"],
                 ("cuda", None): ["USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a", "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a"],
                 ("rocm", (9, 5)): ["USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The", "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The"],
             }).get_expectation()  # fmt: off

         self.assertEqual(decoded_text, expected_decoded_text)
@@ -434,7 +432,6 @@ def test_small_model_integration_test_batch_different_vision_types(self):
         # verify generation
         output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
         decoded_text = self.processor.decode(output[0], skip_special_tokens=True)

-        # Auto-updated when running with UPDATE_EXPECTATIONS=1
         expected_decoded_text = Expectations(
             {
                 ("xpu", None): 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"',