@@ -363,49 +363,6 @@ def test_column_parallel_linear(
363363 )
364364
365365
366- def test_attention (
367- args : ModelArgs ,
368- batch_size : int ,
369- seq_len : int ,
370- dtype : np .dtype ,
371- rank : int = 0 ,
372- world_size : int = 1 ,
373- ):
374- #
375- freqs_cis = precompute_freqs_cis (
376- args .dim // args .n_heads , args .max_seq_len * 2
377- )[0 :seq_len ]
378-
379- freqs_cis_ark = freqs_cis .astype (np .complex64 )
380- freqs_cis_ark = (
381- np .stack ([freqs_cis_ark .real , freqs_cis_ark .imag ], axis = - 1 )
382- .astype (dtype )
383- .reshape (1 , seq_len , 1 , args .dim // args .n_heads )
384- )
385-
386- seed = 1695878986 # int(time.time())
387- print (f"seed: { seed } " )
388- np .random .seed (seed )
389- feature = np .random .uniform (
390- low = - 0.1 , high = 0.1 , size = (batch_size , seq_len , args .dim )
391- ).astype (dtype )
392-
393- test_module (
394- module_class_ark = model_ark .Attention ,
395- module_args_ark = [
396- args ,
397- ark .DataType .from_numpy (dtype ),
398- rank ,
399- world_size ,
400- ],
401- inputs_ark = [feature , 0 , freqs_cis_ark , None ],
402- module_class_pt = model_pt .Attention ,
403- module_args_pt = [args ],
404- inputs_pt = [feature .astype (dtype ), 0 , freqs_cis , None ],
405- module_name_prefix = "layers.0.attention" ,
406- )
407-
408-
409366def test_transformer (
410367 args : ModelArgs ,
411368 batch_size : int ,
@@ -472,7 +429,6 @@ def test(args, batch_size, seq_len, dtype, rank, world_size):
472429 # test_rmsnorm(args, batch_size, seq_len, dtype)
473430 # test_row_parallel_linear(args, batch_size, seq_len, dtype, rank, world_size)
474431 # test_column_parallel_linear(args, batch_size, seq_len, dtype, rank, world_size)
475- # test_attention(args, batch_size, seq_len, dtype, rank, world_size)
476432 test_transformer (args , batch_size , seq_len , dtype , rank , world_size )
477433
478434
0 commit comments