Upload 2 files
Browse files- yuan_hf_model.py +59 -11
- yuan_hf_model_cpu.py +60 -12
yuan_hf_model.py
CHANGED
@@ -25,7 +25,6 @@ import torch
|
|
25 |
import torch.utils.checkpoint
|
26 |
from torch import nn
|
27 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
28 |
-
from transformers.models.llama.modeling_llama import LlamaRMSNorm,LlamaRotaryEmbedding
|
29 |
from transformers.activations import ACT2FN
|
30 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
31 |
from transformers.modeling_utils import PreTrainedModel
|
@@ -58,9 +57,7 @@ class LocalizedFiltering(torch.nn.Module):
|
|
58 |
|
59 |
self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
|
60 |
self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
|
61 |
-
|
62 |
-
#Use the same RMSNorm as llama
|
63 |
-
self.output_layernorm = LlamaRMSNorm(self.embed_dim)
|
64 |
|
65 |
def _train_forward(self, inputs):
|
66 |
inputs = inputs.transpose(0,1)
|
@@ -197,7 +194,61 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|
197 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
198 |
return q_embed, k_embed
|
199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
class YuanMLP(nn.Module):
|
203 |
def __init__(
|
@@ -240,8 +291,7 @@ class YuanAttention(nn.Module):
|
|
240 |
)
|
241 |
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
242 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
243 |
-
|
244 |
-
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
245 |
if self.use_shareqk:
|
246 |
self.qk_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
247 |
self.qk_weight = nn.Parameter(torch.Tensor(2, self.hidden_size))
|
@@ -393,9 +443,8 @@ class YuanDecoderLayer(nn.Module):
|
|
393 |
intermediate_size=config.intermediate_size,
|
394 |
hidden_act=config.hidden_act,
|
395 |
)
|
396 |
-
|
397 |
-
self.
|
398 |
-
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
399 |
|
400 |
def forward(
|
401 |
self,
|
@@ -583,8 +632,7 @@ class YuanModel(YuanPreTrainedModel):
|
|
583 |
self.reset_position_ids = config.reset_position_ids
|
584 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
585 |
self.layers = nn.ModuleList([YuanDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
586 |
-
|
587 |
-
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
588 |
self.gradient_checkpointing = False
|
589 |
# Initialize weights and apply final processing
|
590 |
self.post_init()
|
|
|
25 |
import torch.utils.checkpoint
|
26 |
from torch import nn
|
27 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
28 |
from transformers.activations import ACT2FN
|
29 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
30 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
57 |
|
58 |
self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
|
59 |
self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
|
60 |
+
self.output_layernorm = YuanRMSNorm(self.embed_dim)
|
|
|
|
|
61 |
|
62 |
def _train_forward(self, inputs):
|
63 |
inputs = inputs.transpose(0,1)
|
|
|
194 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
195 |
return q_embed, k_embed
|
196 |
|
197 |
+
class YuanRMSNorm(nn.Module):
|
198 |
+
def __init__(self, hidden_size, eps=1e-6):
|
199 |
+
"""
|
200 |
+
YuanRMSNorm is equivalent to LlamaRMSNorm
|
201 |
+
"""
|
202 |
+
super().__init__()
|
203 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
204 |
+
self.variance_epsilon = eps
|
205 |
+
|
206 |
+
def forward(self, hidden_states):
|
207 |
+
input_dtype = hidden_states.dtype
|
208 |
+
hidden_states = hidden_states.to(torch.float32)
|
209 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
210 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
211 |
+
return self.weight * hidden_states.to(input_dtype)
|
212 |
+
|
213 |
+
class YuanRotaryEmbedding(torch.nn.Module):
|
214 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
215 |
+
|
216 |
+
"""
|
217 |
+
YuanRotaryEmbedding is equivalent to LlamaRotaryEmbedding in transformers v4.36
|
218 |
+
"""
|
219 |
+
|
220 |
+
super().__init__()
|
221 |
+
|
222 |
+
self.dim = dim
|
223 |
+
self.max_position_embeddings = max_position_embeddings
|
224 |
+
self.base = base
|
225 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
226 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
227 |
+
|
228 |
+
# Build here to make `torch.jit.trace` work.
|
229 |
+
self._set_cos_sin_cache(
|
230 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
231 |
+
)
|
232 |
|
233 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
234 |
+
self.max_seq_len_cached = seq_len
|
235 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
236 |
+
|
237 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
238 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
239 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
240 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
241 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
242 |
+
|
243 |
+
def forward(self, x, seq_len=None):
|
244 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
245 |
+
if seq_len > self.max_seq_len_cached:
|
246 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
247 |
+
|
248 |
+
return (
|
249 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
250 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
251 |
+
)
|
252 |
|
253 |
class YuanMLP(nn.Module):
|
254 |
def __init__(
|
|
|
291 |
)
|
292 |
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
293 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
294 |
+
self.rotary_emb = YuanRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
|
|
295 |
if self.use_shareqk:
|
296 |
self.qk_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
297 |
self.qk_weight = nn.Parameter(torch.Tensor(2, self.hidden_size))
|
|
|
443 |
intermediate_size=config.intermediate_size,
|
444 |
hidden_act=config.hidden_act,
|
445 |
)
|
446 |
+
self.input_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
447 |
+
self.post_attention_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
448 |
|
449 |
def forward(
|
450 |
self,
|
|
|
632 |
self.reset_position_ids = config.reset_position_ids
|
633 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
634 |
self.layers = nn.ModuleList([YuanDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
635 |
+
self.norm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
636 |
self.gradient_checkpointing = False
|
637 |
# Initialize weights and apply final processing
|
638 |
self.post_init()
|
yuan_hf_model_cpu.py
CHANGED
@@ -25,7 +25,6 @@ import torch
|
|
25 |
import torch.utils.checkpoint
|
26 |
from torch import nn
|
27 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
28 |
-
from transformers.models.llama.modeling_llama import LlamaRMSNorm,LlamaRotaryEmbedding
|
29 |
from transformers.activations import ACT2FN
|
30 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
31 |
from transformers.modeling_utils import PreTrainedModel
|
@@ -58,9 +57,7 @@ class LocalizedFiltering(torch.nn.Module):
|
|
58 |
|
59 |
self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
|
60 |
self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
|
61 |
-
|
62 |
-
#Use the same RMSNorm as llama
|
63 |
-
self.output_layernorm = LlamaRMSNorm(self.embed_dim)
|
64 |
|
65 |
def _train_forward(self, inputs):
|
66 |
inputs = inputs.transpose(0,1)
|
@@ -197,7 +194,61 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|
197 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
198 |
return q_embed, k_embed
|
199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
class YuanMLP(nn.Module):
|
203 |
def __init__(
|
@@ -240,8 +291,7 @@ class YuanAttention(nn.Module):
|
|
240 |
)
|
241 |
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
242 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
243 |
-
|
244 |
-
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
245 |
if self.use_shareqk:
|
246 |
self.qk_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
247 |
self.qk_weight = nn.Parameter(torch.Tensor(2, self.hidden_size))
|
@@ -268,7 +318,7 @@ class YuanAttention(nn.Module):
|
|
268 |
is_first_step = False
|
269 |
if use_cache:
|
270 |
if past_key_value is None:
|
271 |
-
#
|
272 |
inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
|
273 |
is_first_step = True
|
274 |
else:
|
@@ -393,9 +443,8 @@ class YuanDecoderLayer(nn.Module):
|
|
393 |
intermediate_size=config.intermediate_size,
|
394 |
hidden_act=config.hidden_act,
|
395 |
)
|
396 |
-
|
397 |
-
self.
|
398 |
-
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
399 |
|
400 |
def forward(
|
401 |
self,
|
@@ -583,8 +632,7 @@ class YuanModel(YuanPreTrainedModel):
|
|
583 |
self.reset_position_ids = config.reset_position_ids
|
584 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
585 |
self.layers = nn.ModuleList([YuanDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
586 |
-
|
587 |
-
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
588 |
self.gradient_checkpointing = False
|
589 |
# Initialize weights and apply final processing
|
590 |
self.post_init()
|
|
|
25 |
import torch.utils.checkpoint
|
26 |
from torch import nn
|
27 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
28 |
from transformers.activations import ACT2FN
|
29 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
30 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
57 |
|
58 |
self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
|
59 |
self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
|
60 |
+
self.output_layernorm = YuanRMSNorm(self.embed_dim)
|
|
|
|
|
61 |
|
62 |
def _train_forward(self, inputs):
|
63 |
inputs = inputs.transpose(0,1)
|
|
|
194 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
195 |
return q_embed, k_embed
|
196 |
|
197 |
+
class YuanRMSNorm(nn.Module):
|
198 |
+
def __init__(self, hidden_size, eps=1e-6):
|
199 |
+
"""
|
200 |
+
YuanRMSNorm is equivalent to LlamaRMSNorm
|
201 |
+
"""
|
202 |
+
super().__init__()
|
203 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
204 |
+
self.variance_epsilon = eps
|
205 |
+
|
206 |
+
def forward(self, hidden_states):
|
207 |
+
input_dtype = hidden_states.dtype
|
208 |
+
hidden_states = hidden_states.to(torch.float32)
|
209 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
210 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
211 |
+
return self.weight * hidden_states.to(input_dtype)
|
212 |
+
|
213 |
+
class YuanRotaryEmbedding(torch.nn.Module):
|
214 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
215 |
+
|
216 |
+
"""
|
217 |
+
YuanRotaryEmbedding is equivalent to LlamaRotaryEmbedding in transformers v4.36
|
218 |
+
"""
|
219 |
+
|
220 |
+
super().__init__()
|
221 |
+
|
222 |
+
self.dim = dim
|
223 |
+
self.max_position_embeddings = max_position_embeddings
|
224 |
+
self.base = base
|
225 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
226 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
227 |
+
|
228 |
+
# Build here to make `torch.jit.trace` work.
|
229 |
+
self._set_cos_sin_cache(
|
230 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
231 |
+
)
|
232 |
|
233 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
234 |
+
self.max_seq_len_cached = seq_len
|
235 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
236 |
+
|
237 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
238 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
239 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
240 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
241 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
242 |
+
|
243 |
+
def forward(self, x, seq_len=None):
|
244 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
245 |
+
if seq_len > self.max_seq_len_cached:
|
246 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
247 |
+
|
248 |
+
return (
|
249 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
250 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
251 |
+
)
|
252 |
|
253 |
class YuanMLP(nn.Module):
|
254 |
def __init__(
|
|
|
291 |
)
|
292 |
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
293 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
294 |
+
self.rotary_emb = YuanRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
|
|
295 |
if self.use_shareqk:
|
296 |
self.qk_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
297 |
self.qk_weight = nn.Parameter(torch.Tensor(2, self.hidden_size))
|
|
|
318 |
is_first_step = False
|
319 |
if use_cache:
|
320 |
if past_key_value is None:
|
321 |
+
#inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
|
322 |
inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
|
323 |
is_first_step = True
|
324 |
else:
|
|
|
443 |
intermediate_size=config.intermediate_size,
|
444 |
hidden_act=config.hidden_act,
|
445 |
)
|
446 |
+
self.input_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
447 |
+
self.post_attention_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
448 |
|
449 |
def forward(
|
450 |
self,
|
|
|
632 |
self.reset_position_ids = config.reset_position_ids
|
633 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
634 |
self.layers = nn.ModuleList([YuanDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
635 |
+
self.norm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
636 |
self.gradient_checkpointing = False
|
637 |
# Initialize weights and apply final processing
|
638 |
self.post_init()
|