@@ -8749,18 +8749,29 @@ def INPUT_TYPES(cls):
87498749 FUNCTION = "combine_masks"
87508750
87518751 def combine_masks(self, mask_a, mask_b, mask_c=None, mask_d=None, mask_e=None, mask_f=None):
8752- masks = [mask_a, mask_b]
8753- if mask_c:
8754- masks.append(mask_c)
8755- if mask_d:
8756- masks.append(mask_d)
8757- if mask_e:
8758- masks.append(mask_e)
8759- if mask_f:
8760- masks.append(mask_f)
8761- combined_mask = torch.sum(torch.stack(masks, dim=0), dim=0)
8762- combined_mask = torch.clamp(combined_mask, 0, 1) # Ensure values are between 0 and 1
8763- return (combined_mask, )
8752+ # Gather all masks in a list
8753+ masks = [m for m in [mask_a, mask_b, mask_c, mask_d, mask_e, mask_f] if m is not None]
8754+
8755+ # Skip any masks that are the known "empty" shape [1, 64, 64] from "Preview" etc
8756+ # (You can also use a sum-of-pixels check, or other logic.)
8757+ valid_masks = [m for m in masks if m.shape != (1, 64, 64)]
8758+ # cstr(f"mask shapes: ... `{valid_masks}`").msg.print()
8759+
8760+ # If no valid masks, decide on a fallback
8761+ if len(valid_masks) == 0:
8762+ # Could return a zeroed-out mask, or just return mask_a, or raise a warning
8763+ # Return mask_a so we don't break the graph
8764+ return (mask_a, )
8765+
8766+ # If there is exactly one valid mask, no combine needed
8767+ if len(valid_masks) == 1:
8768+ return (valid_masks[0], )
8769+
8770+ # Otherwise stack, sum, clamp
8771+ combined_mask = torch.sum(torch.stack(valid_masks, dim=0), dim=0)
8772+ combined_mask = torch.clamp(combined_mask, 0, 1) # Keep values in 0..1
8773+
8774+ return (combined_mask,)
87648775
87658776class WAS_Mask_Combine_Batch:
87668777
0 commit comments