@@ -926,52 +926,6 @@ def forward(self, x, position_ids):
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class GraniteFlashAttentionKwargs(TypedDict, total=False):
Contributor Author:
It looks like these just got moved during the regeneration. I'm not sure if they should be included (to enforce consistency with the generation script) or excluded (to minimize the size of the change).

"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.

Attributes:
cu_seq_lens_q (`torch.LongTensor`)
Gets cumulative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`)
Gets cumulative sequence length for key state.
max_length_q (`int`):
Maximum sequence length for query state.
max_length_k (`int`):
Maximum sequence length for key state.
seq_idx (`torch.IntTensor):
Index of each packed sequence.
"""

cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor


@use_kernel_forward_from_hub("RMSNorm")
class GraniteMoeHybridRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        GraniteMoeHybridRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class GraniteMoeHybridParallelExperts(nn.Module):
    def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
        """
@@ -1113,13 +1067,61 @@ def forward(self, layer_input):
        return layer_output


class GraniteFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.

Attributes:
cu_seq_lens_q (`torch.LongTensor`)
Gets cumulative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`)
Gets cumulative sequence length for key state.
max_length_q (`int`):
Maximum sequence length for query state.
max_length_k (`int`):
Maximum sequence length for key state.
seq_idx (`torch.IntTensor):
Index of each packed sequence.
"""

cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor

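As a quick illustration of how these kwargs are typically produced for padding-free batches, here is a minimal sketch; the helper `build_packed_kwargs` is hypothetical and not part of this PR or the library, and the leading batch dimension on `seq_idx` is an assumption about the kernel's expected shape.

import torch


def build_packed_kwargs(seq_lengths: list[int]) -> dict:
    # Hypothetical helper: derive GraniteFlashAttentionKwargs-style values
    # from the lengths of sequences packed into a single batch row.
    lengths = torch.tensor(seq_lengths, dtype=torch.int32)
    # Cumulative sequence boundaries, e.g. [3, 2, 4] -> [0, 3, 5, 9]
    cu_seq_lens = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0).to(torch.int64)])
    # Per-token sequence index, e.g. [0, 0, 0, 1, 1, 2, 2, 2, 2]
    seq_idx = torch.repeat_interleave(torch.arange(len(seq_lengths), dtype=torch.int32), lengths)
    max_len = int(lengths.max())
    return {
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": max_len,
        "max_length_k": max_len,
        "seq_idx": seq_idx.unsqueeze(0),  # assumed (1, total_tokens) layout
    }


packed = build_packed_kwargs([3, 2, 4])
# packed["cu_seq_lens_q"] -> tensor([0, 3, 5, 9])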

@use_kernel_forward_from_hub("RMSNorm")
class GraniteMoeHybridRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        GraniteMoeHybridRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

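A brief, hedged sanity check of the norm above (shapes are illustrative, and this assumes the pure PyTorch forward rather than a hub kernel): the variance is accumulated in float32 and the normalized activations are cast back to the input dtype before the learned scale is applied.

import torch

norm = GraniteMoeHybridRMSNorm(hidden_size=8, eps=1e-6)
x = torch.randn(2, 4, 8)

# Reference computation mirroring forward(): float32 variance, then rescale.
ref = x.to(torch.float32)
ref = ref * torch.rsqrt(ref.pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
ref = norm.weight * ref.to(x.dtype)

assert torch.allclose(norm(x), ref)
print(norm)  # GraniteMoeHybridRMSNorm((8,), eps=1e-06)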

class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Either attention or mamba will be initialized, depending on the layer type.
        self.self_attn = None
        self.block_sparse_moe = GraniteMoeHybridMoE(config)

        # Allow non-MoE (dense)
        self.block_sparse_moe = GraniteMoeHybridMoE(config) if config.num_local_experts > 0 else None
        self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -36,6 +36,7 @@
    GraniteMoeSharedForCausalLM,
    GraniteMoeSharedMLP,
    GraniteMoeSharedModel,
    GraniteMoeSharedMoE,
    GraniteMoeSharedPreTrainedModel,
    eager_attention_forward,
)
@@ -107,6 +108,10 @@ class GraniteMoeHybridRotaryEmbedding(Gemma2RotaryEmbedding):
    pass


class GraniteMoeHybridMoE(GraniteMoeSharedMoE):
    pass


class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
    def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
        super().__init__(config, layer_idx)
@@ -121,6 +126,9 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
        self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
        self.layer_type = config.layers_block_type[layer_idx]

        # Allow non-MoE (dense)
        self.block_sparse_moe = GraniteMoeHybridMoE(config) if config.num_local_experts > 0 else None

        # Accept 0 experts: skip MoE if num_local_experts == 0
        self.has_experts = getattr(config, "num_local_experts", 0) > 0

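For context on the dense fallback, here is a hedged sketch of how the post-attention MLP step of the layer's forward is expected to branch once `block_sparse_moe` can be `None`. The attribute names `has_experts` and `block_sparse_moe` come from the diff, and `shared_mlp` is inherited from GraniteMoeShared; the exact control flow, the presence of a shared MLP, and the Granite-style `residual_multiplier` scaling are assumptions about the surrounding code, not lines from this PR.

import torch


def mlp_block(layer, hidden_states: torch.Tensor) -> torch.Tensor:
    # Sketch of the post-attention MLP/MoE step with the 0-expert (dense) path.
    residual = hidden_states
    hidden_states = layer.post_attention_layernorm(hidden_states)
    if layer.has_experts:
        # Sparse experts plus the always-on shared MLP.
        hidden_states = layer.block_sparse_moe(hidden_states) + layer.shared_mlp(hidden_states)
    else:
        # Dense fallback: only the shared MLP runs.
        hidden_states = layer.shared_mlp(hidden_states)
    # Granite-style scaled residual (assumed to match the existing layer).
    return residual + hidden_states * layer.residual_multiplier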