Commit 7d05da8

updates

1 parent 02bf517
2 files changed: +3 -47 lines

examples/llama/model.py

Lines changed: 3 additions & 3 deletions

@@ -405,7 +405,7 @@ def forward(self, x, ffn_norm):
             ]
         elif seqlen == 128:
             schedule = [
-                [128, [128, 64], 16480],
+                [128, [128, 64], 12384],
             ]
         else:
             raise ValueError(f"Unsupported seqlen {seqlen}")
@@ -471,7 +471,7 @@ def forward(self, x, ffn_norm):
             sram = 24672
         elif seqlen == 128:
             tile = [128, 64]
-            sram = 16480
+            sram = 12384
         else:
             raise ValueError(f"Unsupported seqlen {seqlen}")
 
@@ -645,7 +645,7 @@ def softmax(scores):
             sram = 24672
         elif seqlen == 128:
             tile = [128, 64]
-            sram = 16480
+            sram = 12384
         else:
             raise ValueError(f"Unsupported seqlen {seqlen}")
 
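
Note on the change itself: the same per-seqlen SRAM constant appears at three call sites (the schedule entry in the first hunk and the two `sram =` branches below it), so the new value 12384 has to land in all three. A minimal consistency check, assuming the `[seqlen, tile, sram]` entry shape inferred from the diff and a hypothetical helper name, might look like:

    # Hypothetical consistency check (not part of the commit): the SRAM
    # budget for each sequence length must match across the schedule table
    # and the two `sram =` branches edited above.
    EXPECTED_SRAM = {128: 12384}  # value introduced by this commit

    def check_sram(schedule, branch_sram):
        # schedule: list of [seqlen, tile, sram] entries, as in the first hunk.
        # branch_sram: {seqlen: sram} collected from the if/elif branches.
        for seqlen, _tile, sram in schedule:
            assert sram == EXPECTED_SRAM.get(seqlen, sram), (seqlen, sram)
        for seqlen, sram in branch_sram.items():
            assert sram == EXPECTED_SRAM.get(seqlen, sram), (seqlen, sram)

    check_sram([[128, [128, 64], 12384]], {128: 12384})  # passes after this commit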

examples/llama/model_test.py

Lines changed: 0 additions & 44 deletions

@@ -363,49 +363,6 @@ def test_column_parallel_linear(
     )
 
 
-def test_attention(
-    args: ModelArgs,
-    batch_size: int,
-    seq_len: int,
-    dtype: np.dtype,
-    rank: int = 0,
-    world_size: int = 1,
-):
-    #
-    freqs_cis = precompute_freqs_cis(
-        args.dim // args.n_heads, args.max_seq_len * 2
-    )[0:seq_len]
-
-    freqs_cis_ark = freqs_cis.astype(np.complex64)
-    freqs_cis_ark = (
-        np.stack([freqs_cis_ark.real, freqs_cis_ark.imag], axis=-1)
-        .astype(dtype)
-        .reshape(1, seq_len, 1, args.dim // args.n_heads)
-    )
-
-    seed = 1695878986  # int(time.time())
-    print(f"seed: {seed}")
-    np.random.seed(seed)
-    feature = np.random.uniform(
-        low=-0.1, high=0.1, size=(batch_size, seq_len, args.dim)
-    ).astype(dtype)
-
-    test_module(
-        module_class_ark=model_ark.Attention,
-        module_args_ark=[
-            args,
-            ark.DataType.from_numpy(dtype),
-            rank,
-            world_size,
-        ],
-        inputs_ark=[feature, 0, freqs_cis_ark, None],
-        module_class_pt=model_pt.Attention,
-        module_args_pt=[args],
-        inputs_pt=[feature.astype(dtype), 0, freqs_cis, None],
-        module_name_prefix="layers.0.attention",
-    )
-
-
 def test_transformer(
     args: ModelArgs,
     batch_size: int,
@@ -472,7 +429,6 @@ def test(args, batch_size, seq_len, dtype, rank, world_size):
     # test_rmsnorm(args, batch_size, seq_len, dtype)
     # test_row_parallel_linear(args, batch_size, seq_len, dtype, rank, world_size)
     # test_column_parallel_linear(args, batch_size, seq_len, dtype, rank, world_size)
-    # test_attention(args, batch_size, seq_len, dtype, rank, world_size)
     test_transformer(args, batch_size, seq_len, dtype, rank, world_size)
 
 
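With `test_attention` gone, attention is now exercised only through `test_transformer` via the `test(...)` driver shown in the last hunk. A standalone run would follow the signature in that hunk header; the argument values below are illustrative, not taken from the commit:

    import numpy as np

    # Illustrative invocation of the remaining test driver; the ModelArgs
    # construction and the concrete values here are assumptions.
    args = ModelArgs()  # assumed default-constructible, as in the deleted test
    test(args, batch_size=1, seq_len=128, dtype=np.float16, rank=0, world_size=1)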
