MILVLG commited on
Commit
2524499
·
verified ·
1 Parent(s): cf4ec51

Upload 43 files

Browse files
Files changed (2) hide show
  1. modules/models/XMChat.py +12 -4
  2. modules/shared.py +0 -6
modules/models/XMChat.py CHANGED
@@ -19,6 +19,13 @@ from ..utils import *
19
  from .base_model import BaseLLMModel
20
  from .. import shared
21
 
 
 
 
 
 
 
 
22
  # print('model loading')
23
  # model = AutoModelForCausalLM.from_pretrained(
24
  # "/home/shaozw/labs/imp-v0",
@@ -173,16 +180,17 @@ A chat between a curious user and an artificial intelligence assistant. This art
173
  def get_answer_at_once(self):
174
  # question = self.history[-1]["content"].strip()
175
  # question = f"{self.system_prompt.strip()} USER: <image>\n{question} ASSISTANT:"
 
176
  prompt = self._get_imp_style_inputs()
177
  logging.info(prompt)
178
  # image_tok_cnt = prompt.count('<image>')
179
  # global model, tokenizer
180
- input_ids = shared.state.imp_tokenizer(prompt, return_tensors='pt').input_ids
181
  image_tensor = None
182
  if '<image>' in prompt:
183
  # logging.info("Preprocessing...")
184
- image_tensor = shared.state.imp_model.image_preprocess(self.image_bytes)
185
- output_ids = shared.state.imp_model.generate(
186
  input_ids,
187
  max_new_tokens=3000,
188
  images=image_tensor,
@@ -194,5 +202,5 @@ A chat between a curious user and an artificial intelligence assistant. This art
194
  # repetition_penalty=self.repetition_penalty,
195
  num_return_sequences=1,
196
  use_cache=True)[0]
197
- response = shared.state.imp_tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
198
  return response, len(response)
 
19
  from .base_model import BaseLLMModel
20
  from .. import shared
21
 
22
+ imp_model = AutoModelForCausalLM.from_pretrained(
23
+ "MILVLG/imp-v1-3b",
24
+ torch_dtype=torch.float16,
25
+ device_map="auto",
26
+ trust_remote_code=True)
27
+ imp_tokenizer = AutoTokenizer.from_pretrained("MILVLG/imp-v1-3b", trust_remote_code=True)
28
+
29
  # print('model loading')
30
  # model = AutoModelForCausalLM.from_pretrained(
31
  # "/home/shaozw/labs/imp-v0",
 
180
  def get_answer_at_once(self):
181
  # question = self.history[-1]["content"].strip()
182
  # question = f"{self.system_prompt.strip()} USER: <image>\n{question} ASSISTANT:"
183
+ global imp_model, imp_tokenizer
184
  prompt = self._get_imp_style_inputs()
185
  logging.info(prompt)
186
  # image_tok_cnt = prompt.count('<image>')
187
  # global model, tokenizer
188
+ input_ids = imp_tokenizer(prompt, return_tensors='pt').input_ids
189
  image_tensor = None
190
  if '<image>' in prompt:
191
  # logging.info("Preprocessing...")
192
+ image_tensor = imp_model.image_preprocess(self.image_bytes)
193
+ output_ids = imp_model.generate(
194
  input_ids,
195
  max_new_tokens=3000,
196
  images=image_tensor,
 
202
  # repetition_penalty=self.repetition_penalty,
203
  num_return_sequences=1,
204
  use_cache=True)[0]
205
+ response = imp_tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
206
  return response, len(response)
modules/shared.py CHANGED
@@ -16,12 +16,6 @@ class State:
16
  usage_api_url = USAGE_API_URL
17
  openai_api_base = OPENAI_API_BASE
18
  images_completion_url = IMAGES_COMPLETION_URL
19
- imp_model = AutoModelForCausalLM.from_pretrained(
20
- "MILVLG/imp-v1-3b",
21
- torch_dtype=torch.float16,
22
- device_map="auto",
23
- trust_remote_code=True)
24
- imp_tokenizer = AutoTokenizer.from_pretrained("MILVLG/imp-v1-3b", trust_remote_code=True)
25
 
26
  def interrupt(self):
27
  self.interrupted = True
 
16
  usage_api_url = USAGE_API_URL
17
  openai_api_base = OPENAI_API_BASE
18
  images_completion_url = IMAGES_COMPLETION_URL
 
 
 
 
 
 
19
 
20
  def interrupt(self):
21
  self.interrupted = True