25 changes: 25 additions & 0 deletions models/baichuan/7b_config.json
@@ -0,0 +1,25 @@
{
"emb_size": 4096,
"feedforward_size": 11008,
"hidden_size": 4096,
"hidden_act": "silu",
"heads_num": 32,
"layers_num": 32,
"max_seq_length": 4096,
"dropout": 0.0,
"data_processor": "lm",
"embedding": ["word"],
"remove_transformer_bias": true,
"remove_embedding_layernorm": true,
"rotary_position_embedding": true,
"encoder": "transformer",
"feed_forward": "gated",
"mask": "causal",
"layernorm_positioning": "pre",
"layernorm": "rms",
"target": ["lm"],
"normHead": true,
"z_loss": true,
"use_xformers": true,
"baichuan_RoPE": true
}
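
The last four keys (normHead, z_loss, use_xformers, baichuan_RoPE) are the Baichuan-specific switches introduced by this PR; they mirror the command-line flags added in tencentpretrain/opts.py further down. A minimal sketch for inspecting them, with the path as committed here:

import json

# Print the Baichuan-specific options from the config added in this PR.
with open("models/baichuan/7b_config.json") as f:
    config = json.load(f)
for key in ("normHead", "z_loss", "use_xformers", "baichuan_RoPE"):
    print(key, config[key])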
7 changes: 7 additions & 0 deletions models/baichuan_special_tokens_map.json
@@ -0,0 +1,7 @@
{
"pad_token": "<unk>",
"unk_token": "<unk>",
"cls_token": "<s>",
"sep_token": "</s>",
"mask_token": "<mask>"
}
69 changes: 69 additions & 0 deletions scripts/convert_baichuan_from_huggingface_to_tencentpretrain.py
@@ -0,0 +1,69 @@
import argparse
import collections
import torch
import os
import json
import torch.nn as nn

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input_model_path", type=str, default="models/baichuan-7b/",
help=".")
parser.add_argument("--output_model_path", type=str, default="models/baichuan-7b.bin",
help=".")
parser.add_argument("--type", choices=["7B"], default="7B")

args = parser.parse_args()

model_config = {"7B": [32, 4096, 32]}

layers_num, dim, n_heads = model_config[args.type]

files = os.listdir(args.input_model_path)
model_files = [f for f in files if f[-4:] == ".bin"]
input_models = {f: torch.load(os.path.join(args.input_model_path, f), map_location="cpu") for f in model_files}

with open(os.path.join(args.input_model_path, "pytorch_model.bin.index.json")) as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]


output_model = collections.OrderedDict()

def get_weight_from_name(layer_name):
return input_models[weight_map[layer_name]][layer_name]

def unpermute(w):
return w.reshape(n_heads, 2, dim // n_heads // 2, dim).transpose(2, 1).reshape(dim, dim)

output_model["embedding.word.embedding.weight"] = get_weight_from_name("model.embed_tokens.weight")

for i in range(layers_num):
# W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
W_pack = get_weight_from_name("model.layers." + str(i) + ".self_attn.W_pack.weight")
output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.0.weight"] = \
W_pack[0: dim]
output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.1.weight"] = \
W_pack[dim: dim * 2]
output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers.2.weight"] = \
W_pack[dim * 2: dim * 3]

output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"] = \
get_weight_from_name("model.layers." + str(i) + ".self_attn.o_proj.weight")

output_model["encoder.transformer." + str(i) + ".layer_norm_1.weight"] = \
get_weight_from_name("model.layers." + str(i) + ".input_layernorm.weight")

output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"] = \
get_weight_from_name("model.layers." + str(i) + ".mlp.gate_proj.weight")
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"] = \
get_weight_from_name("model.layers." + str(i) + ".mlp.up_proj.weight")
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"] = \
get_weight_from_name("model.layers." + str(i) + ".mlp.down_proj.weight")

output_model["encoder.transformer." + str(i) + ".layer_norm_2.weight"] = \
get_weight_from_name("model.layers." + str(i) + ".post_attention_layernorm.weight")

output_model["encoder.layer_norm.weight"] = get_weight_from_name("model.norm.weight")
output_model["target.lm.output_layer.weight"] = nn.functional.normalize(get_weight_from_name("lm_head.weight").float())

torch.save(output_model, args.output_model_path)
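
For reference, a small sanity check on the converted checkpoint; the path matches the script's default output and the key names follow the mapping above.

import torch

# Load the converted TencentPretrain checkpoint and spot-check a few keys.
state = torch.load("models/baichuan-7b.bin", map_location="cpu")
assert "embedding.word.embedding.weight" in state
assert "target.lm.output_layer.weight" in state
# The packed QKV weight is split into three square [dim, dim] projections per layer.
print(state["encoder.transformer.0.self_attn.linear_layers.0.weight"].shape)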
80 changes: 52 additions & 28 deletions tencentpretrain/layers/multi_headed_attn.py
@@ -2,9 +2,15 @@
import torch
import torch.nn as nn
from tencentpretrain import mpu
from tencentpretrain.utils.rope import apply_rotary_emb
from tencentpretrain.utils.rope import apply_rotary_emb, apply_rotary_emb_baichuan
from tencentpretrain.utils.lora import LoraLinear

# xformers is optional: if it is not installed, fall back to the standard attention path.
try:
    from xformers import ops as xops
except ImportError:
    xops = None


def repeat_kv(x: torch.Tensor, repeat_num: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
@@ -27,8 +33,10 @@ class MultiHeadedAttention(nn.Module):
"""

def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_num, dropout, has_bias=True, with_scale=True,
lora_params=None, layer_number=None):
lora_params=None, layer_number=None, use_xformers=False, baichuan_RoPE=False):
super(MultiHeadedAttention, self).__init__()
self.baichuan_RoPE = baichuan_RoPE
self.use_xformers = use_xformers
self.heads_num = heads_num
self.per_head_size = attention_head_size
self.with_scale = with_scale
@@ -105,40 +113,56 @@ def unshape(x):


if freqs_cis is not None:
query, key = apply_rotary_emb(query.transpose(1,2), key.transpose(1,2), freqs_cis=freqs_cis)
if self.baichuan_RoPE:
query, key = apply_rotary_emb_baichuan(query, key, freqs_cis=freqs_cis)
else:
query, key = apply_rotary_emb(query.transpose(1,2), key.transpose(1,2), freqs_cis=freqs_cis)


scores = torch.matmul(query, key.transpose(-2, -1))
prev_attn_out = None
# xformers attention
if xops is not None and self.use_xformers and self.training:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
            value = value.transpose(1, 2)
output = xops.memory_efficient_attention(
query, key, value, attn_bias=xops.LowerTriangularMask()
)
output = unshape(output)
else:
scores = torch.matmul(query, key.transpose(-2, -1))

if position_bias is not None:
scores = scores + position_bias
if position_bias is not None:
scores = scores + position_bias

if self.with_scale:
if self.layer_number is not None:
scores = scores * (1.0 / self.norm_factor)
else:
scores = scores / math.sqrt(float(per_head_size))
if alibi is not None:
scores = scores.reshape((-1, scores.shape[-2], scores.shape[-1]))
scores += (1.0 / self.layer_number) * alibi
scores = scores.view(-1, heads_num, scores.shape[-2], scores.shape[-1])
if self.with_scale:
if self.layer_number is not None:
scores = scores * (1.0 / self.norm_factor)
else:
scores = scores / math.sqrt(float(per_head_size))
if alibi is not None:
scores = scores.reshape((-1, scores.shape[-2], scores.shape[-1]))
scores += (1.0 / self.layer_number) * alibi
scores = scores.view(-1, heads_num, scores.shape[-2], scores.shape[-1])

scores = scores + mask.type_as(scores)
scores = scores + mask.type_as(scores)

# scaled softmax
if self.layer_number is not None:
scores = (scores * self.layer_number) + mask
scores = torch.max(scores, torch.tensor(-10000))
# scaled softmax
if self.layer_number is not None:
scores = (scores * self.layer_number) + mask
scores = torch.max(scores, torch.tensor(-10000))

prev_attn_out = None
prev_attn_out = None

if has_residual_attention:
if prev_attn is not None:
scores += prev_attn
prev_attn_out = scores
if has_residual_attention:
if prev_attn is not None:
scores += prev_attn
prev_attn_out = scores

probs = nn.Softmax(dim=-1)(scores)
probs = self.dropout(probs)
output = unshape(torch.matmul(probs, value))

probs = nn.Softmax(dim=-1)(scores)
probs = self.dropout(probs)
output = unshape(torch.matmul(probs, value))
output = self.final_linear(output)
return output, prev_attn_out

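A standalone sketch of the xformers fast path added above: memory_efficient_attention expects [batch, seq_len, heads, head_dim] tensors, which is why query/key/value are transposed before the call, and LowerTriangularMask supplies the causal mask. The shapes and the CUDA/fp16 setup below are illustrative only.

import torch
try:
    from xformers import ops as xops
except ImportError:
    xops = None

if xops is not None and torch.cuda.is_available():
    # [batch, seq_len, heads, head_dim]
    q = torch.randn(1, 16, 32, 128, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = xops.memory_efficient_attention(
        q, k, v, attn_bias=xops.LowerTriangularMask()  # causal attention
    )
    print(out.shape)  # torch.Size([1, 16, 32, 128])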
3 changes: 2 additions & 1 deletion tencentpretrain/layers/transformer.py
@@ -40,7 +40,8 @@ def __init__(self, args, layer_number=None):

self.self_attn = MultiHeadedAttention(
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias,
with_scale = with_scale, lora_params=lora_params, layer_number=layer_number
with_scale = with_scale, lora_params=lora_params, layer_number=layer_number, use_xformers=args.use_xformers,
baichuan_RoPE = args.baichuan_RoPE
)
self.dropout_1 = nn.Dropout(args.dropout)

7 changes: 7 additions & 0 deletions tencentpretrain/opts.py
@@ -52,6 +52,10 @@ def model_opts(parser):
help="whether use alibi position embedding.")
parser.add_argument("--layer_number_scale", action="store_true",
help="whether use layer number scaling.")
parser.add_argument("--use_xformers", action="store_true",
help="whether use xformers")
parser.add_argument("--baichuan_RoPE", action="store_true",
help="whether use baichuan_RoPE")

vision_opts(parser)
audio_opts(parser)
@@ -111,6 +115,9 @@ def optimization_opts(parser):
"constant", "constant_with_warmup", "inverse_sqrt", "tri_stage"],
default="linear", help="Scheduler type.")

    parser.add_argument("--normHead", action="store_true",
                        help="whether to use the NormHead output layer.")
    parser.add_argument("--z_loss", action="store_true",
                        help="whether to add the max-z auxiliary loss.")


def training_opts(parser):
parser.add_argument("--batch_size", type=int, default=32,
21 changes: 20 additions & 1 deletion tencentpretrain/targets/lm_target.py
@@ -1,8 +1,10 @@
import torch
import torch.nn as nn
import math

from tencentpretrain import mpu
from tencentpretrain.utils.constants import *
from tencentpretrain.utils.NormHead import NormHead


class LmTarget(nn.Module):
@@ -39,6 +41,12 @@ def __init__(self, args, vocab_size):
self.softmax = nn.LogSoftmax(dim=-1)
self.criterion = nn.NLLLoss()

        # NormHead output layer and max-z auxiliary loss (Baichuan-style).
        self.normHead = args.normHead
        if args.normHead:
            self.lm_head = NormHead(args.hidden_size, vocab_size, bias=False)
        self.z_loss = args.z_loss

def lm(self, memory_bank, tgt_lm, seg):
# Language modeling (LM) with full softmax prediction.
seg = seg.contiguous().view(-1)
@@ -52,8 +60,16 @@ def lm(self, memory_bank, tgt_lm, seg):
if tgt_lm is not None:
tgt_lm = tgt_lm.contiguous().view(-1)
tgt_lm = tgt_lm[seg > loss_mask]

if self.normHead:
output = self.lm_head(memory_bank)
else:
output = self.output_layer(memory_bank)

        if self.z_loss:
            # max-z auxiliary loss: penalize the squared top logit, scaled by 2e-4.
            softmax_normalizer = output.max(-1).values ** 2
            z_loss = 2e-4 * softmax_normalizer.mean()

output = self.output_layer(memory_bank)
if self.pipeline_model_parallel_size > 1:

return output, loss_mask
@@ -83,6 +99,9 @@ def lm(self, memory_bank, tgt_lm, seg):
eps_i = self.label_smoothing / (output.size(-1) - 1)
loss = (1.0 - self.label_smoothing - eps_i) * nll_loss + eps_i * smooth_loss

if self.z_loss:
loss += z_loss

return loss

def forward(self, memory_bank, tgt, seg):
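The max-z term used in lm_target.py above reduces to a few lines; a minimal sketch with example shapes, assuming the 2e-4 coefficient from the Baichuan 2 recipe:

import torch

# Square the largest logit per position and average, scaled by 2e-4.
logits = torch.randn(4, 10, 1000)          # [batch, seq_len, vocab]
softmax_normalizer = logits.max(-1).values ** 2
z_loss = 2e-4 * softmax_normalizer.mean()
print(z_loss.item())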
22 changes: 22 additions & 0 deletions tencentpretrain/utils/NormHead.py
@@ -0,0 +1,22 @@
import torch
import math
from torch import nn


class NormHead(nn.Module):
    """Bias-free output head whose weight rows are L2-normalized before the projection."""

    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.first_flag = True

    def forward(self, hidden_states):
        if self.training:
            # Training: re-normalize the weight on every forward pass.
            norm_weight = nn.functional.normalize(self.weight)
        elif self.first_flag:
            # First eval call: normalize once and cache the normalized weight.
            self.first_flag = False
            self.weight = nn.Parameter(nn.functional.normalize(self.weight))
            norm_weight = self.weight
        else:
            norm_weight = self.weight
        return nn.functional.linear(hidden_states, norm_weight)
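
A hypothetical usage sketch for NormHead: it behaves like a bias-free linear output layer whose weight rows are L2-normalized (re-normalized every training step, normalized once and cached in eval mode). Sizes below are examples.

import torch
from tencentpretrain.utils.NormHead import NormHead

head = NormHead(hidden_size=4096, vocab_size=1000)
hidden = torch.randn(2, 8, 4096)           # [batch, seq_len, hidden_size]
logits = head(hidden)
print(logits.shape)                        # torch.Size([2, 8, 1000])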
23 changes: 23 additions & 0 deletions tencentpretrain/utils/rope.py
@@ -28,3 +28,26 @@ def apply_rotary_emb(
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).transpose(1,2), xk_out.type_as(xk).transpose(1,2)


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb_baichuan(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
cos_sin = torch.view_as_real(freqs_cis)
cos = cos_sin[:, :, 0]
sin = cos_sin[:, :, 1]
cos = torch.cat((cos, cos), dim=-1).unsqueeze(0).unsqueeze(0)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(0).unsqueeze(0)
q_embed = (xq.float() * cos) + (rotate_half(xq.float()) * sin)
k_embed = (xk.float() * cos) + (rotate_half(xk.float()) * sin)
return q_embed.to(xq.dtype), k_embed.to(xk.dtype)
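
apply_rotary_emb_baichuan applies RoPE in the rotate_half convention to query/key tensors shaped [batch, heads, seq_len, head_dim]. The precompute for freqs_cis is not part of this diff, so the helper below is an assumption that follows the standard RoPE construction and produces the complex [seq_len, head_dim // 2] tensor the function expects.

import torch
from tencentpretrain.utils.rope import apply_rotary_emb_baichuan

def precompute_freqs_cis(head_dim, seq_len, theta=10000.0):
    # Standard RoPE frequency table as a complex tensor: [seq_len, head_dim // 2].
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

freqs_cis = precompute_freqs_cis(head_dim=128, seq_len=16)
q = torch.randn(1, 32, 16, 128)            # [batch, heads, seq_len, head_dim]
k = torch.randn_like(q)
q_rot, k_rot = apply_rotary_emb_baichuan(q, k, freqs_cis)
print(q_rot.shape, q_rot.dtype)            # torch.Size([1, 32, 16, 128]) torch.float32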