hlky (Contributor) commented Nov 30, 2025

What does this PR do?

import torch
from diffusers import GGUFQuantizationConfig
from diffusers.models import ZImageTransformer2DModel
from huggingface_hub import hf_hub_download

# Load a GGUF-quantized checkpoint; compute_dtype sets the dtype used for computation.
model = ZImageTransformer2DModel.from_single_file(
    hf_hub_download("jayn7/Z-Image-Turbo-GGUF", "z_image_turbo-Q3_K_S.gguf"),
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Load a ComfyUI-format bf16 safetensors checkpoint (independent of the GGUF example above).
model = ZImageTransformer2DModel.from_single_file(
    hf_hub_download(
        "Comfy-Org/z_image_turbo",
        "split_files/diffusion_models/z_image_turbo_bf16.safetensors",
    )
)

See https://huggingface.co/Comfy-Org/z_image_turbo/blob/main/z_image_convert_original_to_comfy.py

Fixes #12748

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Vargol commented Dec 1, 2025

Sorry if this shows up as a duplicate; I thought I had commented on this ages ago, but there's no sign of it.
Hi, thanks for this. The model I tried loads but it doesn't work. It looks like there's some code in the transformer that checks the dtype of the weights, gets the int8 storage dtype instead of the GGUF compute dtype, and then (eventually) calls torch.nn.Linear with the wrong types.

The code is in the TimestepEmbedder forward function:

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        weight_dtype = self.mlp[0].weight.dtype
        if weight_dtype.is_floating_point:
            t_freq = t_freq.to(weight_dtype)
        t_emb = self.mlp(t_freq)
        return t_emb

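self.mlp[0].weight.dtype returns int8 for a GGUF-format model, so the is_floating_point check fails, the cast is skipped, and t_freq keeps its original float32 dtype.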

This leads to the forward code

        t_emb = self.mlp(t_freq)

eventually calling output = torch.nn.functional.linear(inputs, weight, bias) with torch.float32, torch.bfloat16, torch.bfloat16 arguments.
On MPS this fails with "Destination NDArray and Accumulator NDArray cannot have different datatype in MPSNDArrayMatrixMultiplication". I don't have any CUDA or other devices to check whether this is a generic issue.
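
The mismatch described above can be sketched on CPU without a GGUF checkpoint. This is only an illustration, not the model's code: the shapes and the helper name `try_linear` are made up for the example, and whether the mixed-dtype call fails, and with which message, may differ by backend and PyTorch version.

```python
import torch
import torch.nn.functional as F


def try_linear(inputs, weight, bias):
    """Call F.linear and report the output dtype, or the error message if it fails."""
    try:
        return True, F.linear(inputs, weight, bias).dtype
    except RuntimeError as exc:
        return False, str(exc)


# Hypothetical shapes; only the dtypes mirror the report above.
t_freq = torch.randn(2, 8)                      # float32, like the uncast t_freq
weight = torch.randn(4, 8).to(torch.bfloat16)   # weight in the compute dtype
bias = torch.zeros(4).to(torch.bfloat16)

# Mixed dtypes: float32 input against a bfloat16 weight and bias.
print(try_linear(t_freq, weight, bias))

# Casting the input to the compute dtype first makes all three dtypes agree.
print(try_linear(t_freq.to(torch.bfloat16), weight, bias))
```

Running this on another device (by moving the three tensors with `.to(device)`) would be one way to check whether the failure is specific to MPS.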

If I hardcode the right type, I can generate an image without issue:

        t_freq = t_freq.to(self.mlp[0].compute_dtype)

Presumably it will need to be incorporated properly, with an attribute check for compute_dtype as part of the dtype-setting code, rather than my brute-force method.
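
One way such an attribute check could look is sketched below. It is not the code that ended up in the PR: `resolve_input_dtype` and `FakeQuantLinear` are hypothetical names, and the only assumption taken from the thread is that a quantized layer exposes a `compute_dtype` attribute, as the hardcoded workaround above relies on.

```python
import torch
import torch.nn as nn


def resolve_input_dtype(layer, current_dtype):
    """Pick the dtype that inputs should be cast to before calling `layer`."""
    # Prefer an explicit compute dtype if the layer advertises one.
    compute_dtype = getattr(layer, "compute_dtype", None)
    if compute_dtype is not None:
        return compute_dtype
    # Otherwise follow the weight dtype, but only if it is a floating-point type.
    weight_dtype = layer.weight.dtype
    if weight_dtype.is_floating_point:
        return weight_dtype
    # Integer storage with no compute dtype: leave the input dtype unchanged.
    return current_dtype


class FakeQuantLinear:
    """Hypothetical stand-in for a quantized layer: int8 storage, bfloat16 compute."""

    def __init__(self):
        self.weight = torch.zeros(4, 4, dtype=torch.int8)
        self.compute_dtype = torch.bfloat16


t_freq = torch.randn(2, 4)  # float32 by default

quant_layer = FakeQuantLinear()
print(resolve_input_dtype(quant_layer, t_freq.dtype))   # compute dtype wins over int8 storage

plain_layer = nn.Linear(4, 4).to(torch.float16)
print(resolve_input_dtype(plain_layer, t_freq.dtype))   # falls back to the weight dtype
```

The fallback order keeps the existing behaviour for ordinary floating-point layers and only changes the result when a layer declares a compute dtype.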

hlky (Contributor, Author) commented Dec 1, 2025

@Vargol Apologies, I only tested loading. Something like c84e6d7 should work

Vargol commented Dec 1, 2025

That looks like it'll work. I'll give it a quick test.

Vargol commented Dec 1, 2025

Yep - that's worked, no errors only images :-)

@sayakpaul sayakpaul requested a review from DN6 December 3, 2025 07:43
# Match t_embedder output dtype to x for layerwise casting compatibility
adaln_input = t.type_as(x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x[torch.cat(x_inner_pad_mask).to(x.device)] = self.x_pad_token.to(x.device)
Collaborator

Just a question. Why the device cast here? Is it to fix something else?

Contributor Author

Oh, I meant to remove that. For context, this patch was shared in the community to fix layer offloading in one of the training UIs. I was curious what changes they had made and forgot to revert them before I started this branch. I'm not sure whether it's related to Diffusers offloading or specific to the third-party repo. Removed in da06a2c.
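
For readers unfamiliar with the line under review, the masked pad-token assignment can be sketched as follows. The shapes, mask values, and token contents here are assumptions for illustration; only the variable names follow the diff.

```python
import torch

dim = 4
x = torch.zeros(5, dim)  # five token rows from two samples, concatenated

# Hypothetical per-sample boolean masks; True marks a padded row.
x_inner_pad_mask = [
    torch.tensor([False, False, True]),
    torch.tensor([False, True]),
]
x_pad_token = torch.ones(1, dim)  # stand-in for the learned pad token

# Concatenate the per-sample masks and overwrite the padded rows.
mask = torch.cat(x_inner_pad_mask)
x[mask] = x_pad_token  # the (1, dim) token broadcasts over the selected rows

print(x)
```

When the mask, the token, and `x` already live on the same device, adding `.to(x.device)` returns the same tensors unchanged, which is consistent with the cast being dropped from this PR.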

cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats[torch.cat(cap_inner_pad_mask).to(cap_feats.device)] = self.cap_pad_token.to(cap_feats.device)
Collaborator

Just a question. Why the device cast here? Is it to fix something else?

DN6 (Collaborator) left a comment

Thanks @hlky 👍🏽

@DN6 DN6 merged commit 6028613 into huggingface:main Dec 4, 2025
9 of 11 checks passed
Development

Successfully merging this pull request may close these issues.

Please support GGUF format models for Z-Image-Turbo