There appears to be a mismatch between the RoPE frequency table initialization and its usage during training in gpt.py.
Current Behavior
In Transformer.__init__():
self.freqs_cis = precompute_freqs_cis_2d(
    grid_size,
    self.config.dim // self.config.n_head,
    self.config.rope_base,
    self.cls_token_num + self.condition_token_num  # = 1 + 256 = 257
)
The precompute_freqs_cis_2d function builds a frequency table whose first 257 positions (reserved for the class and condition tokens) are all zeros, followed by the actual 2D RoPE frequencies for the image tokens.
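For reference, the zero prefix comes from how the table is assembled. Below is a minimal sketch, assuming an implementation along the lines of LlamaGen's precompute_freqs_cis_2d; argument names and exact shapes are assumptions, not copied from this repo:

import torch

def precompute_freqs_cis_2d(grid_size, n_elem, base=10000, cls_token_num=257):
    # Axial 2D RoPE frequencies for a grid_size x grid_size image grid,
    # with a block of zeros prepended for the class/condition tokens.
    half_dim = n_elem // 2
    freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: half_dim // 2].float() / half_dim))
    t = torch.arange(grid_size, device=freqs.device)
    freqs = torch.outer(t, freqs)                      # (grid_size, n_elem // 4)
    freqs_grid = torch.cat([
        freqs[:, None, :].expand(-1, grid_size, -1),   # row (y) frequencies
        freqs[None, :, :].expand(grid_size, -1, -1),   # column (x) frequencies
    ], dim=-1)                                         # (grid_size, grid_size, n_elem // 2)
    cache = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1).flatten(0, 1)
    # Zero rows for the conditioning tokens come FIRST; image-token rows follow.
    return torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])

With cls_token_num + condition_token_num = 257 passed in, the torch.zeros(...) block is what occupies positions 0..256 of the returned table.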
During training in forward():
token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
# Shape: (bs, 1 + 255, dim) = (bs, 256, dim)
if self.training:
    freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]  # takes the first 256 positions
Problem
Since token_embeddings.shape[1] = 256 and the first 257 positions of freqs_cis are zeros, the sliced freqs_cis[:256] contains all zeros. This means the entire training sequence effectively has no positional encoding from RoPE.
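To make this concrete, a quick check (using the sketch above, or the repo's own function, with hypothetical hyperparameters: a 16x16 grid and head dim 64) shows that the slice used during training is entirely zero:

freqs_cis = precompute_freqs_cis_2d(grid_size=16, n_elem=64, base=10000, cls_token_num=257)
seq_len = 256  # 1 condition embedding + 255 image-token embeddings, as in forward()
sliced = freqs_cis[:seq_len]
print(sliced.shape)                  # torch.Size([256, 32, 2])
print(torch.count_nonzero(sliced))   # tensor(0) -> no rotary phase ever reaches the attention layers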