YenJung committed on
Commit
473ef6f
·
1 Parent(s): 71b9fae

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +150 -0
model.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from collections import namedtuple
3
+
4
+ import click
5
+ import torch
6
+ from peft import PeftModel
7
+ from transformers import (
8
+ AutoModel,
9
+ AutoTokenizer,
10
+ BloomForCausalLM,
11
+ BloomTokenizerFast,
12
+ GenerationConfig,
13
+ LlamaForCausalLM,
14
+ LlamaTokenizer,
15
+ )
16
+ from utils import generate_prompt
17
+
18
+
19
def decide_model(args, device_map):
    """Load the tokenizer and the PEFT-wrapped base model selected by ``args``.

    Args:
        args: Object exposing ``model_type``, ``base_model`` and
            ``finetuned_weights`` attributes. An optional ``load_8bit``
            attribute decides whether the non-chatglm base model is loaded
            in 8-bit; it is treated as ``True`` when the attribute is absent.
        device_map: Forwarded to the base model's ``from_pretrained``. When
            it equals ``"auto"`` the adapter is loaded with
            ``torch_dtype=torch.float16``; otherwise ``device_map`` is
            forwarded to the adapter loader as well.

    Returns:
        A ``(tokenizer, model)`` tuple, where ``model`` is the value returned
        by ``PeftModel.from_pretrained`` for ``args.finetuned_weights``.
    """
    ModelClass = namedtuple("ModelClass", ("tokenizer", "model"))
    model_classes = {
        "llama": ModelClass(tokenizer=LlamaTokenizer, model=LlamaForCausalLM),
        # ChatGLM ships its own classes as remote code, hence the Auto* loaders.
        "chatglm": ModelClass(tokenizer=AutoTokenizer, model=AutoModel),
        "bloom": ModelClass(tokenizer=BloomTokenizerFast, model=BloomForCausalLM),
        "Auto": ModelClass(tokenizer=AutoTokenizer, model=AutoModel),
    }
    # Any model type without a dedicated entry falls back to the Auto* classes.
    known_types = ("llama", "bloom", "chatglm")
    model_type = args.model_type if args.model_type in known_types else "Auto"
    classes = model_classes[model_type]

    if model_type == "chatglm":
        tokenizer = classes.tokenizer.from_pretrained(
            args.base_model,
            trust_remote_code=True,
        )
        # TODO: ChatGLMForConditionalGeneration revision
        model = classes.model.from_pretrained(
            args.base_model,
            trust_remote_code=True,
            device_map=device_map,
        )
    else:
        tokenizer = classes.tokenizer.from_pretrained(args.base_model)
        model = classes.model.from_pretrained(
            args.base_model,
            # Honour the caller's flag instead of always forcing 8-bit;
            # default to True so callers without the attribute are unaffected.
            load_in_8bit=getattr(args, "load_8bit", True),
            torch_dtype=torch.float16,
            device_map=device_map,
        )

    if model_type == "llama":
        tokenizer.pad_token_id = 0
        tokenizer.padding_side = "left"  # Allow batched inference

    if device_map == "auto":
        model = PeftModel.from_pretrained(
            model,
            args.finetuned_weights,
            torch_dtype=torch.float16,
        )
    else:
        model = PeftModel.from_pretrained(
            model,
            args.finetuned_weights,
            device_map=device_map,
        )
    return tokenizer, model
78
+
79
+
80
class ModelServe:
    """Load a base model with fine-tuned PEFT weights and generate replies.

    Attributes set by ``__init__``:
        device: ``"cuda:0"`` when CUDA is available, otherwise ``"cpu"``.
        device_map: ``"auto"`` on CUDA, otherwise ``{"": "cpu"}``.
        tokenizer, model: The pair returned by ``decide_model``.
    """

    # Marker that separates the prompt from the reply in the decoded output.
    _RESPONSE_MARKER = "### 回覆:"

    def __init__(
        self,
        load_8bit: bool = True,
        model_type: str = "llama",
        base_model: str = "linhvu/decapoda-research-llama-7b-hf",
        finetuned_weights: str = "/home/holiday01/Downloads/LLaMa/alpaca-7b-chinese/finetuned/llama-7b-hf_alpaca-en-zh",
    ):
        """Load the tokenizer and model and prepare the model for inference.

        Args:
            load_8bit: When false, the loaded model is converted with
                ``.half()`` after loading.
            model_type: Model family name handed to ``decide_model``.
            base_model: Base model identifier or path.
            finetuned_weights: Location of the fine-tuned PEFT weights.
        """
        # Build the argument record explicitly; ``locals()`` would also
        # capture ``self`` as a field.
        arg_record = namedtuple(
            "args",
            ("load_8bit", "model_type", "base_model", "finetuned_weights"),
        )
        local_args = arg_record(
            load_8bit=load_8bit,
            model_type=model_type,
            base_model=base_model,
            finetuned_weights=finetuned_weights,
        )

        if torch.cuda.is_available():
            self.device = "cuda:0"
            self.device_map = "auto"
        else:
            self.device = "cpu"
            self.device_map = {"": self.device}

        self.tokenizer, self.model = decide_model(
            args=local_args, device_map=self.device_map
        )

        # unwind broken decapoda-research config
        # NOTE(review): applied for every model_type, not only llama --
        # confirm these ids are also right for bloom/chatglm checkpoints.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id = 0  # unk
        self.model.config.bos_token_id = 1
        self.model.config.eos_token_id = 2

        if not load_8bit:
            self.model.half()  # seems to fix bugs for some users.

        self.model.eval()
        if torch.__version__ >= "2" and sys.platform != "win32":
            self.model = torch.compile(self.model)

    @staticmethod
    def _extract_response(decoded: str) -> str:
        """Return the stripped text that follows the first response marker.

        Text after a second marker, if any, is discarded. When the marker is
        absent the whole decoded string is returned stripped instead of
        raising ``IndexError``.
        """
        marker = ModelServe._RESPONSE_MARKER
        _, found, tail = decoded.partition(marker)
        if not found:
            return decoded.strip()
        return tail.split(marker, 1)[0].strip()

    def generate(
        self,
        instruction: str,
        input: str,
        temperature: float = 0.7,
        top_p: float = 0.75,
        top_k: int = 40,
        num_beams: int = 4,
        max_new_tokens: int = 1024,
        **kwargs
    ) -> str:
        """Generate a reply for ``instruction`` and ``input``.

        The prompt is built with ``generate_prompt(instruction, input)``,
        tokenized, moved to ``self.device`` and passed to
        ``self.model.generate``. The prompt and the decoded output are
        printed to stdout.

        Args:
            instruction: Instruction text for the prompt.
            input: Additional input text for the prompt.
            temperature, top_p, top_k, num_beams: Forwarded to
                ``GenerationConfig``.
            max_new_tokens: Forwarded to ``self.model.generate``.
            **kwargs: Extra ``GenerationConfig`` keyword arguments.

        Returns:
            The stripped text following the first ``"### 回覆:"`` marker in
            the decoded first sequence, or the whole stripped decoded text
            when the marker does not occur.
        """
        prompt = generate_prompt(instruction, input)
        print(f"Prompt: {prompt}")
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        print("generating...")
        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        first_sequence = generation_output.sequences[0]
        output = self.tokenizer.decode(first_sequence)
        print(f"Output: {output}")
        return self._extract_response(output)