158 changes: 155 additions & 3 deletions convert_hf_to_gguf.py
@@ -8957,6 +8957,155 @@
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))


@ModelBase.register("MegrezMoEForCausalLM")
class MegrezMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.MEGREZ_MOE

def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

tokpre = self.get_vocab_base_pre(tokenizer)
merges = []
vocab = {}
mergeable_ranks = getattr(tokenizer, "mergeable_ranks", {})
for token, rank in mergeable_ranks.items():
vocab[QwenModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
if len(merged) == 2:
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

vocab_size = self.hparams["vocab_size"]
assert tokenizer.vocab_size == vocab_size
special_tokens = getattr(tokenizer, "special_tokens", {})
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items() if id_ is not None}
tokens: list[str] = []
toktypes: list[int] = []
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token = reverse_vocab[i]
if token is None:
tokens.append(f"[PAD{i}]")
else:
tokens.append(str(token))
if i in special_tokens.values():
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.NORMAL)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_token_merges(merges)

special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.add_to_gguf(self.gguf_writer)
# BOS token fix if needed
# self.gguf_writer.add_bos_token_id(<id>)

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams

num_experts = hparams.get("num_experts")
if num_experts is not None:
self.gguf_writer.add_expert_count(int(num_experts))
intermediate_size = hparams.get("intermediate_size")
if intermediate_size is not None:
self.gguf_writer.add_expert_shared_feed_forward_length(int(intermediate_size))

moe_intermediate_size = hparams.get("moe_intermediate_size")
if moe_intermediate_size is not None and isinstance(moe_intermediate_size, (list, tuple)) and len(moe_intermediate_size) > 0:
assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
self.gguf_writer.add_expert_feed_forward_length(int(moe_intermediate_size[0]))

moe_topk = hparams.get("moe_topk")
if moe_topk is not None and isinstance(moe_topk, (list, tuple)) and len(moe_topk) > 0:
assert all(topk == moe_topk[0] for topk in moe_topk)
self.gguf_writer.add_expert_used_count(int(moe_topk[0]))

moe_shared_expert = hparams.get("num_shared_expert")
if moe_shared_expert is not None and isinstance(moe_shared_expert, (list, tuple)) and len(moe_shared_expert) > 0:
assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
self.gguf_writer.add_expert_shared_count(int(moe_shared_expert[0]))

rope_scaling = hparams.get("rope_scaling") or {}  # config may set rope_scaling to null
if rope_scaling.get("type") == "dynamic":
alpha = rope_scaling.get("alpha", 1000)
base = hparams.get("rope_theta", 10000.0)
hidden_size = hparams.get("hidden_size")
num_attention_heads = hparams.get("num_attention_heads")
max_position_embeddings = self.hparams.get("max_position_embeddings")
if None not in (hidden_size, num_attention_heads, max_position_embeddings):
try:
dim = int(hidden_size) // int(num_attention_heads)
except (TypeError, ValueError):
dim = None
if dim is not None and dim > 2:
scaled_base = base * (alpha ** (dim / (dim - 2)))
self.gguf_writer.add_rope_freq_base(scaled_base)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_rope_scaling_factor(1)
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)
self.gguf_writer.add_context_length(256 * 1024)
assert alpha == 1000 and base == 10000.0 and dim == 128 and max_position_embeddings in [32 * 1024, 256 * 1024], \
"Megrez dynamic RoPE scaling assumptions changed, please update the logic or context length manually"

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name == "lm_head.weight":
if self.hparams.get("tie_word_embeddings", False):
logger.info("Skipping tied output layer 'lm_head.weight'")
return []

if name.find("mlp.experts") != -1:
n_experts = self.hparams.get("num_experts")
if n_experts is None or bid is None:
return []
n_experts = int(n_experts)
bid = int(bid)

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))

return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]
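# Shape note (sketch, assuming the usual HF nn.Linear layout of [out_features, in_features]):
# stacking n_experts per-expert weights along dim 0 yields [n_experts, n_ff_exp, n_embd] for
# gate_proj/up_proj and [n_experts, n_embd, n_ff_exp] for down_proj, which map_tensor_name
# then routes to the ffn_gate_exps / ffn_up_exps / ffn_down_exps GGUF tensors.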

def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
@@ -8971,7 +9120,7 @@
# 2. Reverse-engineer the merges list from mergeable_ranks
merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks

(GitHub Actions pyright type-check reported a failure on line 9123: Argument of type "Any | None" cannot be assigned to parameter "x" of type "ConvertibleToInt".)
for token, rank in mergeable_ranks.items():
vocab[QwenModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
@@ -8984,16 +9133,19 @@
vocab_size = self.hparams["vocab_size"]
assert tokenizer.vocab_size == vocab_size
special_tokens = tokenizer.special_tokens
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items() if id_ is not None}
tokens: list[str] = []
toktypes: list[int] = []
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
tokens.append(f"[PAD{i}")
toktypes.append(gguf.TokenType.UNUSED)
else:
token = reverse_vocab[i]
tokens.append(token)
if token is None:
tokens.append(f"[PAD{i}")
else:
tokens.append(str(token))
if i in special_tokens.values():
toktypes.append(gguf.TokenType.CONTROL)
else:
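A note on the merges reconstruction used by the set_vocab methods above: the tiktoken-style tokenizer only ships a rank table (token bytes → merge priority), so the converter re-derives each BPE merge rule by re-running BPE on every multi-byte token with that token's own rank as the cutoff; if exactly two parts remain, they form the merge. A minimal standalone sketch of the idea — the bpe helper and toy rank table below are illustrative stand-ins, not the script's QwenModel.bpe:

def bpe(ranks: dict[bytes, int], token: bytes, max_rank: int) -> list[bytes]:
    # Repeatedly merge the lowest-rank adjacent pair, but stop before applying any
    # merge whose rank is >= max_rank: the result is the state just before `token`
    # itself would have been formed.
    parts = [bytes([b]) for b in token]
    while True:
        best_idx, best_rank = None, None
        for i in range(len(parts) - 1):
            rank = ranks.get(parts[i] + parts[i + 1])
            if rank is not None and (best_rank is None or rank < best_rank):
                best_idx, best_rank = i, rank
        if best_rank is None or best_rank >= max_rank:
            break
        parts = parts[:best_idx] + [parts[best_idx] + parts[best_idx + 1]] + parts[best_idx + 2:]
    return parts

toy_ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}
merges = []
for token, rank in toy_ranks.items():
    if len(token) == 1:
        continue
    parts = bpe(toy_ranks, token, max_rank=rank)
    if len(parts) == 2:
        merges.append(" ".join(p.decode() for p in parts))
print(merges)  # ['a b', 'ab c']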
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
@@ -427,6 +427,7 @@ class MODEL_ARCH(IntEnum):
COGVLM = auto()
MINIMAXM2 = auto()
PANGU_EMBED = auto()
MEGREZ_MOE = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -795,6 +796,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.MINIMAXM2: "minimax-m2",
MODEL_ARCH.COGVLM: "cogvlm",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MEGREZ_MOE: "megrez-moe",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -89,6 +89,7 @@ add_library(llama
models/mamba.cpp
models/minicpm3.cpp
models/minimax-m2.cpp
models/megrez-moe.cpp
models/mpt.cpp
models/nemotron-h.cpp
models/nemotron.cpp
26 changes: 26 additions & 0 deletions src/llama-arch.cpp
@@ -108,6 +108,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
{ LLM_ARCH_COGVLM, "cogvlm" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MEGREZ_MOE, "megrez-moe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -2378,6 +2379,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_MEGREZ_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_PANGU_EMBED,
{
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -112,6 +112,7 @@ enum llm_arch {
LLM_ARCH_MINIMAX_M2,
LLM_ARCH_COGVLM,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MEGREZ_MOE,
LLM_ARCH_UNKNOWN,
};

16 changes: 15 additions & 1 deletion src/llama-context.cpp
@@ -1382,7 +1382,21 @@ void llama_context::output_reorder() {
//

uint32_t llama_context::graph_max_nodes() const {
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
uint32_t base_nodes = std::max<uint32_t>(1024u, 8u*model.n_tensors());

// Megrez-MoE creates many intermediate tensors in build_mergez_moe_ffn for each layer:
// - sigmoid, add (bias), reshape (3x), get_rows, sum_rows, div, view_2d, mul_mat (per expert)
// - ggml_top_k internally calls ggml_argsort + ggml_view_4d (2 more tensors per layer)
// Each MoE layer needs ~30-40 intermediate tensors during graph construction
// With 30 MoE layers, this adds significant overhead to the graph (30 layers * 35 tensors = ~1050)
// During warmup, the graph is built 3 times with different batch sizes
if (model.arch == LLM_ARCH_MEGREZ_MOE) {
// Add substantial headroom: ~35 intermediate tensors per MoE layer * 30 layers = ~1050 extra nodes;
// reserve 4096 to keep a comfortable safety margin across warmup's triple graph construction
base_nodes += 4096;
}
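// Back-of-the-envelope with a hypothetical tensor count (illustration only, not measured):
//   n_tensors ~ 400  =>  base_nodes = max(1024, 8*400) = 3200; +4096 => 7296 nodes reserved,
//   comfortably above the ~1050 extra MoE nodes estimated above, even across repeated warmup builds.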

return base_nodes;
}

llm_graph_result * llama_context::get_gf_res_reserve() const {
78 changes: 78 additions & 0 deletions src/llama-model.cpp
@@ -2180,12 +2180,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_PANGU_EMBED:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1
case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_MEGREZ_MOE:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);

switch (hparams.n_layer) {
case 31: type = LLM_TYPE_7B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}

@@ -3338,6 +3352,65 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
}
} break;
case LLM_ARCH_MEGREZ_MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

// Layer 0 is dense, layers 1-30 are MoE
if (i == 0) {
// Dense layer
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
} else {
// All MoE layers (1-30) have these
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, i), {n_expert}, 0);

if (n_expert == 0) {
throw std::runtime_error("n_expert must be > 0 for MEGREZ_MOE");
}
if (n_expert_used == 0) {
throw std::runtime_error("n_expert_used must be > 0 for MEGREZ_MOE");
}

// All MoE layers have shared expert
const int64_t n_ff_shexp = hparams.n_ff_shexp;
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);

// Only layers 1, 4, 7, 10, 13, 16, 19, 22, 25, 28 have actual expert tensors
// Pattern: (i-1) % 3 == 0 for i > 0
if ((i - 1) % 3 == 0) {
// MoE branch - use the expert-specific FF size from hparams
const int64_t n_ff_exp = hparams.n_ff_exp;

layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
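// Cross-check with the converter: convert_hf_to_gguf.py stacks per-expert weights along the
// expert dimension, so a torch tensor of shape [n_expert, n_ff_exp, n_embd] arrives here as a
// GGML tensor with ne = {n_embd, n_ff_exp, n_expert} (GGML lists dimensions fastest-varying
// first), matching the shapes requested above.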
}
// Note: layers that share experts (2, 3, 5, 6, etc.) only have gate_inp and shared expert
// They will reference the regular experts from their corresponding "full" layer during inference
}
}
} break;
case LLM_ARCH_QWEN3:
case LLM_ARCH_QWEN3VL:
{
@@ -7178,6 +7251,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_jais>(*this, params);
} break;
case LLM_ARCH_MEGREZ_MOE:
{
llm = std::make_unique<llm_build_megrez_moe>(*this, params);
} break;
case LLM_ARCH_NEMOTRON:
{
llm = std::make_unique<llm_build_nemotron>(*this, params);
@@ -7518,6 +7595,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GPTNEOX:
case LLM_ARCH_CODESHELL:
case LLM_ARCH_ORION:
case LLM_ARCH_MEGREZ_MOE:
case LLM_ARCH_NEMOTRON:
case LLM_ARCH_EXAONE:
case LLM_ARCH_EXAONE4: