
LLaMA 13B and 70B fail on CPU with BF16 #778

@Yufeng98

Description

LLaMA 7B runs fine on CPU with both BF16 and FP32, but LLaMA 13B and 70B only run on CPU with FP32.

With BF16, LLaMA 13B and 70B fail in the embedding layer with RuntimeError: Invalid scalar type:

```
Traceback (most recent call last):
  File "/data1/llama-cpu/example_text_completion.py", line 70, in <module>
    fire.Fire(main)
  File "/home/anaconda3/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/anaconda3/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/anaconda3/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/data1/llama-cpu/example_text_completion.py", line 57, in main
    results = generator.text_completion(
  File "/data1/llama-cpu/llama/generation.py", line 264, in text_completion
    generation_tokens, generation_logprobs = self.generate(
  File "/data1/llama-cpu/llama/generation.py", line 181, in generate
    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
  File "/data1/llama-cpu/llama/model.py", line 471, in forward
    h = self.tok_embeddings(tokens)
  File "/home/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/anaconda3/lib/python3.9/site-packages/fairscale/nn/model_parallel/layers.py", line 214, in forward
    output = gather_from_model_parallel_region(output_parallel)
  File "/home/anaconda3/lib/python3.9/site-packages/fairscale/nn/model_parallel/mappings.py", line 156, in gather_from_model_parallel_region
    return _GatherFromModelParallelRegion.apply(input_)
  File "/home/anaconda3/lib/python3.9/site-packages/fairscale/nn/model_parallel/mappings.py", line 131, in forward
    return _gather(input_)
  File "/home/anaconda3/lib/python3.9/site-packages/fairscale/nn/model_parallel/mappings.py", line 82, in _gather
    torch.distributed.all_gather(tensor_list, input_, group=group)
  File "/home/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2075, in all_gather
    work.wait()
RuntimeError: Invalid scalar type
```
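
The gather in fairscale's ParallelEmbedding goes through torch.distributed.all_gather, and on CPU that collective typically runs over the gloo backend, which rejects unsupported dtypes with exactly this RuntimeError. That would also explain why 7B is unaffected: its checkpoint ships as a single shard (model-parallel size 1), so no gather is issued, while 13B (2 shards) and 70B (8 shards) split the embedding across ranks and must gather. Here is a minimal sketch to check the hypothesis; the gloo backend is my assumption, since the traceback does not name it:

```python
# Minimal repro sketch: all_gather a BF16 tensor over gloo on CPU.
# Assumption: the failing run uses the gloo backend for CPU collectives.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    out = [torch.empty_like(x) for _ in range(world_size)]
    # On gloo builds without BF16 collective support this raises
    # RuntimeError: Invalid scalar type
    dist.all_gather(out, x)

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

If this fails while the same script with dtype=torch.float32 succeeds, the problem is the backend's dtype support, not the model code.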
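
If BF16 on CPU is the goal, one stopgap is to route the model-parallel gather through FP32. The sketch below monkey-patches fairscale's private _gather helper (the one visible in the traceback); patching internals like this is an assumption about fairscale's layout, not a supported API, and the FP32 round-trip costs extra memory and bandwidth:

```python
# Stopgap sketch: gather BF16 activations in FP32, then cast back.
# _gather is a private fairscale helper; this patch is illustrative only.
import torch
import fairscale.nn.model_parallel.mappings as mappings

_orig_gather = mappings._gather

def _gather_bf16_safe(input_: torch.Tensor) -> torch.Tensor:
    if input_.dtype is torch.bfloat16:
        # gloo handles FP32; round-trip through it to avoid "Invalid scalar type"
        return _orig_gather(input_.float()).to(torch.bfloat16)
    return _orig_gather(input_)

mappings._gather = _gather_bf16_safe
```

The patch has to run before the first forward pass. Otherwise, running 13B and 70B in FP32, as described above, sidesteps the unsupported dtype entirely.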
