Upload model.py
model.py
CHANGED
@@ -23,7 +23,7 @@ class RMSNorm(torch.nn.Module):
         return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)


-def precompute_pos_cis(dim: int, end: int, theta: float =
+def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
     freqs = torch.outer(t, freqs).float()  # type: ignore
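The tail of precompute_pos_cis is not shown in this hunk, but the visible lines are the standard rotary-embedding (RoPE) precomputation: per-channel-pair inverse frequencies theta ** (-2i/dim), then an outer product with the position index, and, in typical implementations, a conversion to unit complex rotations. Below is a minimal standalone sketch under that assumption; the function name and the final torch.polar step are illustrative, not lines from this diff.

import torch

def precompute_pos_cis_sketch(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    # Inverse frequency per channel pair: theta ** (-2i / dim) for i = 0 .. dim/2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Positions 0 .. end-1; angle[t, i] = t * freqs[i] via the outer product.
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # Assumed final step (not shown in the hunk): unit-magnitude complex rotations e^{i * angle}.
    return torch.polar(torch.ones_like(freqs), freqs)

# One complex rotation per (position, channel pair): shape (end, dim // 2).
pos_cis = precompute_pos_cis_sketch(dim=64, end=1024)
assert pos_cis.shape == (1024, 32) and pos_cis.is_complex()

With the new defaults the table covers 32K positions; per the next hunk it is built with the per-head dimension (params.dim // params.n_heads), which keeps it small even for long contexts.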
@@ -295,8 +295,9 @@ class MiniMindLM(PreTrainedModel):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.tok_embeddings.weight = self.output.weight
-        self.register_buffer("pos_cis",
-
+        self.register_buffer("pos_cis",
+                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
+                             persistent=False)
         self.OUT = CausalLMOutputWithPast()

     def forward(self,
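The substantive change here is persistent=False on the rebuilt register_buffer call: the precomputed table still follows the module through .to(...) like any other buffer, but it is excluded from state_dict(), so checkpoints no longer carry data that precompute_pos_cis can simply recreate. A small sketch of that behavior with a toy module (the class name is made up):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        table = torch.arange(8).float()              # stand-in for the pos_cis table
        self.register_buffer("pos_cis", table, persistent=False)

m = Toy().to(torch.float64)
print(m.pos_cis.dtype)                               # torch.float64 -- the buffer is tracked by .to()
print("pos_cis" in m.state_dict())                   # False -- left out of checkpoints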
@@ -328,13 +329,13 @@ class MiniMindLM(PreTrainedModel):
                  stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
         # streaming generation
         if stream:
-            return self.
+            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)

         # direct (non-streaming) generation
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
-            out = self.
+            out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
             tokens_list = [tokens[:, -1:] for tokens in out]
             gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
             full_sequence = torch.cat([non_pad, gen], dim=-1)
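Both branches of generate now go through the same _stream helper; the batch branch first strips padding from each row, because each prompt in a padded batch has its own true length and _stream works on a single unpadded sequence. The masking trick on the non_pad line is self-contained enough to demo directly:

import torch

pad_token_id = 0
# A right-padded batch of two prompts with different true lengths.
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])

for i in range(input_ids.size(0)):
    # Same pattern as the hunk: boolean-mask out pad tokens, keep a batch dim of 1.
    non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
    print(non_pad.shape, non_pad.tolist())
# torch.Size([1, 3]) [[5, 6, 7]]
# torch.Size([1, 2]) [[8, 9]]

The tokens[:, -1:] slice on the following line suggests each value yielded by _stream ends with the newest token, so collecting the last column of every yield reassembles the generated continuation.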
@@ -348,14 +349,14 @@ class MiniMindLM(PreTrainedModel):
         ]
         return torch.cat(generated, dim=0)

-    def
+    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
         while input_ids.shape[1] < max_new_tokens - 1:
             if first_seq or not use_cache:
-                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
+                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
             else:
                 out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
-                           start_pos=input_ids.shape[1] - 1)
+                           start_pos=input_ids.shape[1] - 1, **args)
             logits, past_kvs = out.logits[:, -1, :], out.past_key_values
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
             logits /= (temperature + 1e-9)
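The hunk stops right after the temperature scaling, but together with the top_p and rp arguments it outlines a conventional decode step: divide the logits of tokens already present in the sequence by rp (a repetition penalty), flatten or sharpen the distribution with temperature, then sample. The sketch below reproduces the two visible lines and adds a typical top-p (nucleus) sampling continuation; everything after the temperature line is an assumption, and sample_next_token is a hypothetical helper, not a function from this file.

import torch
import torch.nn.functional as F

def sample_next_token(logits, input_ids, rp=1.0, temperature=0.75, top_p=0.9):
    logits = logits.clone()
    # The two lines visible at the end of the hunk: repetition penalty, then temperature.
    logits[:, list(set(input_ids.tolist()[0]))] /= rp
    logits /= (temperature + 1e-9)

    # Assumed continuation: keep the smallest set of tokens whose probability mass
    # reaches top_p, renormalize, and sample one token id.
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0    # always keeps the top token
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, next_sorted)

# Toy check: a 1 x 16 logits row and a running sequence that repeats token 3.
print(sample_next_token(torch.randn(1, 16), torch.tensor([[3, 3, 7]]), rp=1.2).shape)  # torch.Size([1, 1])

Note that dividing by rp only lowers a logit when it is positive, which matches the line in the hunk; stronger penalty schemes usually divide positive logits and multiply negative ones.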