finalf0 commited on
Commit
38b2f06
·
1 Parent(s): d253fbc

Update code

Browse files
Files changed (1) hide show
  1. modeling_minicpm.py +32 -5
modeling_minicpm.py CHANGED
@@ -20,7 +20,7 @@
20
  """ PyTorch MiniCPM model."""
21
  import math
22
  import warnings
23
- from typing import List, Optional, Tuple, Union
24
 
25
  import torch
26
  import torch.nn.functional as F
@@ -49,11 +49,13 @@ from transformers.utils import (
49
  )
50
  from transformers.utils.import_utils import is_torch_fx_available
51
  from .configuration_minicpm import MiniCPMConfig
 
52
 
53
-
54
- if is_flash_attn_2_available():
55
  from flash_attn import flash_attn_func, flash_attn_varlen_func
56
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
 
57
 
58
 
59
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
@@ -124,7 +126,7 @@ ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
124
 
125
 
126
  class MiniCPMRotaryEmbedding(nn.Module):
127
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda"):
128
  super().__init__()
129
 
130
  self.dim = dim
@@ -762,7 +764,6 @@ class MiniCPMDecoderLayer(nn.Module):
762
  def __init__(self, config: MiniCPMConfig, layer_idx: int):
763
  super().__init__()
764
  self.hidden_size = config.hidden_size
765
-
766
  self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
767
 
768
  self.mlp = MiniCPMMLP(config)
@@ -1302,6 +1303,32 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
1302
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1303
  )
1304
  return reordered_past
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1305
 
1306
 
1307
  @add_start_docstrings(
 
20
  """ PyTorch MiniCPM model."""
21
  import math
22
  import warnings
23
+ from typing import List, Optional, Tuple, Union, Dict
24
 
25
  import torch
26
  import torch.nn.functional as F
 
49
  )
50
  from transformers.utils.import_utils import is_torch_fx_available
51
  from .configuration_minicpm import MiniCPMConfig
52
+ import re
53
 
54
+ try:
 
55
  from flash_attn import flash_attn_func, flash_attn_varlen_func
56
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
57
+ except:
58
+ pass
59
 
60
 
61
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 
126
 
127
 
128
  class MiniCPMRotaryEmbedding(nn.Module):
129
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
130
  super().__init__()
131
 
132
  self.dim = dim
 
764
  def __init__(self, config: MiniCPMConfig, layer_idx: int):
765
  super().__init__()
766
  self.hidden_size = config.hidden_size
 
767
  self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
768
 
769
  self.mlp = MiniCPMMLP(config)
 
1303
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1304
  )
1305
  return reordered_past
1306
+
1307
+ @torch.inference_mode()
1308
+ def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1309
+ max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
1310
+ **kwargs):
1311
+ if history is None:
1312
+ history = []
1313
+ if logits_processor:
1314
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1315
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1316
+ else:
1317
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1318
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1319
+
1320
+ history.append({"role": role, "content": query})
1321
+ history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
1322
+ inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
1323
+ outputs = self.generate(**inputs, **gen_kwargs)
1324
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1325
+ response = tokenizer.decode(outputs)
1326
+ pattern = re.compile(r".*?(?=<AI>|<用户>)", re.DOTALL)
1327
+ matches = pattern.findall(response)
1328
+ if len(matches) > 0:
1329
+ response = matches[0]
1330
+ history.append({"role": "assistant", "content": response})
1331
+ return response, history
1332
 
1333
 
1334
  @add_start_docstrings(