Skip to content

Conversation

@tamarPal
Copy link
Contributor

@tamarPal tamarPal commented Nov 6, 2025

Summary

This PR adds full support for the Megrez-MoE (Mixture of Experts) architecture to llama.cpp, enabling inference on Megrez2-3x7B models and similar MoE variants.

Architecture Details

Megrez-MoE is a Mixture of Experts architecture with:

  • 64 experts with top-6 selection per layer
  • 4 shared experts across all tokens
  • Sigmoid + bias gating mechanism (different from standard softmax gating)
  • 30 MoE layers with 2048 embedding dimension
  • Context length up to 163,840 tokens

Changes Made

1. Architecture Registration

  • Added LLM_ARCH_MEGREZ_MOE architecture enum
  • Registered MoE-specific hyperparameters (expert counts, FFN dimensions)
  • Added tensor mapping for 64 experts × 30 layers

2. MoE FFN Implementation

Implemented build_mergez_moe_ffn() with:

  • Sigmoid gating with bias (unique to Megrez-MoE)
  • Top-K expert selection using ggml_top_k()
  • Shared experts processing for all tokens
  • Per-expert feed-forward computation

3. Model Loading

  • Added llm_build_megrez_moe class
  • Implemented hyperparameter loading (expert_count, expert_used_count, etc.)
  • Implemented tensor loading for all expert weights

4. Graph Memory Fix

  • Problem: Warmup crashed with "not enough space in context's memory pool"
  • Root Cause: MoE FFN creates ~35 intermediate tensors per layer (sigmoid, reshape, top_k, etc.)
  • Solution: Added 4096 node overhead to graph_max_nodes() for Megrez-MoE (30 layers × 35 tensors ≈ 1050 nodes, doubled for safety)

Testing

  • All 39 existing tests pass
  • No regression in other architectures (verified with Gemma-2)
  • Warmup works without crashes
  • Output generation verified (up to 200 tokens)
  • Performance: ~17 tokens/second on CPU

Comparison

# Without this PR:
$ ./build/bin/llama-cli -m Megrez2-3x7B.gguf -p "Test"
error: unknown model architecture: 'megrez-moe'

# With this PR:
$ ./build/bin/llama-cli -m Megrez2-3x7B.gguf -p "Test"
Works perfectly with warmup enabled

@tamarPal
Copy link
Contributor Author

tamarPal commented Nov 6, 2025

@pwilkin @filippide
Added full support for Megrez-MoE architecture
Tested: All 39 tests pass, no regression on other models, ~17 tok/s on CPU.

Comment on lines 1086 to 1089
// For Megrez: sigmoid THEN add bias (not the other way around!)
normalized_logits = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
cb(normalized_logits, "ffn_moe_logits_normalize", il);
probs = ggml_add(ctx0, normalized_logits, exp_probs_b); // Add bias AFTER sigmoid
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about ffn_moe_probs_biased in build_moe_ffn ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question!
The key difference is that Megrez-MoE uses sigmoid + bias instead of softmax, with a different computation order.
In standard build_moe_ffn, the biased probabilities are: probs_biased = softmax(logits + bias) with bias BEFORE softmax.
But Megrez-MoE requires: probs = sigmoid(logits) + bias with bias AFTER sigmoid.
The bias is added after the activation function (not before), and we use sigmoid (per-expert independent scores) instead of softmax (normalized distribution). This is specific to the Megrez-MoE architecture design.
That's why I couldn't reuse ffn_moe_probs_biased directly - both the activation function and computation order are fundamentally different.
Thank's!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what ffn_moe_probs_biased does though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In standard build_moe_ffn, the biased probabilities are: probs_biased = softmax(logits + bias) with bias BEFORE softmax.

This is ffn_moe_logits_biased, not ffn_moe_probs_biased

@CISC
Copy link
Collaborator

CISC commented Nov 6, 2025

Seems this has been badly merged.

tamarPal added 4 commits November 6, 2025 23:45
Implements complete support for Megrez-MoE (Mixture of Experts) models:

- Add LLM_ARCH_MEGREZ_MOE architecture enum and mappings
- Implement build_mergez_moe_ffn() with sigmoid+bias gating
- Add llm_build_megrez_moe class for full model graph construction
- Support 31-layer architecture (layer 0: dense FFN, layers 1-30: MoE)
- Implement expert sharing pattern with 64 experts, 6 used per token, 4 shared
- Load all model hyperparameters and 372 tensors correctly
- Configure NEOX RoPE type for proper positional encoding

Tested with Megrez2-3x7B-A3B_Q4_K_M.gguf model.
All 39 llama.cpp tests pass successfully.
Output verified to match infinigence/llama.cpp reference implementation.

Note: Use --no-warmup flag to avoid warmup memory allocation issue.
Megrez-MoE creates many intermediate tensors during MoE FFN construction:
- sigmoid, add, reshape (3x), get_rows, sum_rows, div, view_2d, mul_mat operations
- ggml_top_k internally calls ggml_argsort + ggml_view_4d (2 more tensors per layer)
- Each of 30 MoE layers creates ~35 intermediate tensors during graph construction

During warmup, the graph is built 3 times with different batch sizes, requiring
sufficient memory pool space for all intermediate tensors.

Add 4096 node overhead for LLM_ARCH_MEGREZ_MOE to accommodate these intermediate
tensors (30 layers × 35 tensors/layer ≈ 1050 nodes, doubled for safety margin).

This fixes the 'not enough space in the context's memory pool' error during warmup,
allowing Megrez-MoE to work without the --no-warmup flag.

Tested:
- All 39 tests pass
- Megrez-MoE works with warmup enabled (no crashes)
- Other models (e.g., Gemma-2) are unaffected
- Verified with outputs up to 100 tokens
- Move llm_build_megrez_moe from llama-model.cpp to src/models/megrez-moe.cpp
- Add declaration to src/models/models.h
- Update CMakeLists.txt to include megrez-moe.cpp in build
- Resolve merge conflicts in llama-arch.cpp and llama-model.cpp
- Fix PANGU_EMBED case statement closing braces

The model loads successfully, all tests pass (40/40), and inference works correctly.
…oe_ffn

- Remove custom build_mergez_moe_ffn implementation (100+ lines)
- Use existing build_moe_ffn with LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID
- Pre-compute gate logits from pre_gate_hidden (Megrez-MoE's unique gating)
- Pass pre-computed logits via probs_in parameter
- Maintain exact same behavior and output quality

This addresses review feedback to reuse existing MoE infrastructure
instead of duplicating code. The sigmoid gating + bias after activation
is already supported by build_moe_ffn.
@github-actions github-actions bot added the model Model specific label Nov 6, 2025
@tamarPal
Copy link
Contributor Author

tamarPal commented Nov 7, 2025

@ngxson @CISC
Thank you for the review!

Fixed:

Merge conflicts - Rebased on latest master,
Removed duplicate code - Now using standard [build_moe_ffn]
Model output remains identical.
Thank's!

@CISC
Copy link
Collaborator

CISC commented Nov 7, 2025

Merge conflicts - Rebased on latest master

Better, but there's still some fixes to be made, check the diffs in Files changed, also the conversion code seem to have gone missing.

- Restore PANGU_EMBED and COGVLM tensor mappings in llama-arch.cpp
- Remove extra blank line in llama-context.cpp
@tamarPal
Copy link
Contributor Author

tamarPal commented Nov 9, 2025

Hi @CISC,

Thanks for the review! I've fixed the merge issues in commit.
The branch now merges cleanly with master and all tests pass.
Thank's!

@CISC
Copy link
Collaborator

CISC commented Nov 9, 2025

Thanks for the review! I've fixed the merge issues in commit.

Still missing conversion code. :)

@tamarPal
Copy link
Contributor Author

tamarPal commented Nov 9, 2025

Hi @CISC,

Thank you for your feedback!
I'm currently working on adding the Megrez-MoE conversion code to GGUF, but this will take some time to implement and test properly.
In the meantime, is it possible to merge the current PR?
The C++ code already provides full inference support for Megrez-MoE models in GGUF format,

Thank you for your consideration!

@CISC
Copy link
Collaborator

CISC commented Nov 9, 2025

Thank you for your feedback! I'm currently working on adding the Megrez-MoE conversion code to GGUF, but this will take some time to implement and test properly. In the meantime, is it possible to merge the current PR? The C++ code already provides full inference support for Megrez-MoE models in GGUF format,

No, conversion code must be included in the PR.

@tamarPal
Copy link
Contributor Author

tamarPal commented Nov 9, 2025

Thank you for your feedback! I'm currently working on adding the Megrez-MoE conversion code to GGUF, but this will take some time to implement and test properly. In the meantime, is it possible to merge the current PR? The C++ code already provides full inference support for Megrez-MoE models in GGUF format,

No, conversion code must be included in the PR.

@tamarPal tamarPal closed this Nov 9, 2025
@tamarPal tamarPal reopened this Nov 9, 2025
// Use standard build_moe_ffn but with pre-computed gate logits
ggml_tensor * moe_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[((il - 1) / (3) * (3)) + 1].ffn_up_exps,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is ((il - 1) / (3) * (3)) + 1 different from il ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because il is an integer, the division /3 performs integer division — so this expression effectively groups layers in blocks of 3 and returns the first layer index of the block.

@github-actions github-actions bot added the python python script changes label Nov 9, 2025
@CISC
Copy link
Collaborator

CISC commented Nov 9, 2025

You are making what appears to be complete random changes that have nothing to do with this model, you can open this PR again once you've sorted it out properly.

@CISC CISC closed this Nov 9, 2025
@tamarPal
Copy link
Contributor Author

tamarPal commented Nov 9, 2025

Got it, thanks for the feedback. I’ll clean up the unrelated changes and make sure the PR only includes the model-specific parts for MegrezMoE. Once it’s properly scoped, I’ll reopen it.

@tamarPal
Copy link
Contributor Author

Hi @CISC,

Fixed the conversion code and tested with a real model (Infinigence/Megrez2-3x7B-A3B).

  • Conversion works
  • Model loads and runs correctly
  • All tests passing
  • Changes only affect Megrez-MoE

Could you please reopen this PR?

Thanks!

@CISC
Copy link
Collaborator

CISC commented Nov 10, 2025

Could you please reopen this PR?

Sorry, GitHub gives this error, guess you'll have to resubmit after all:

The feature/megrez-moe branch was force-pushed or recreated.

@tamarPal
Copy link
Contributor Author

Could you please reopen this PR?

Sorry, GitHub gives this error, guess you'll have to resubmit after all:

Thanks for trying to reopen!

Would you prefer:

I open a new PR with the fixed code?
Or is there another way you'd like me to handle this?
The fixes are ready and tested - conversion works with real models, all tests pass, and the changes are minimal (only affects Megrez-MoE).

Thanks!

@CISC
Copy link
Collaborator

CISC commented Nov 10, 2025

I open a new PR with the fixed code?

Yes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

model Model specific python python script changes

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants