tangzhy committed
Commit e330996 · verified · 1 Parent(s): 55e52ce

Update app.py

Files changed (1)
  1. app.py +45 -40
app.py CHANGED
@@ -15,6 +15,8 @@ from transformers import (
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
+from vllm import LLM, SamplingParams
+
 DESCRIPTION = """\
 # ORLM LLaMA-3-8B
 
@@ -24,6 +26,7 @@ Hello! I'm ORLM-LLaMA-3-8B, here to automate your optimization modeling tasks! C
 MAX_MAX_NEW_TOKENS = 4096
 DEFAULT_MAX_NEW_TOKENS = 4096
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
 
 # quantization_config = BitsAndBytesConfig(
 #    load_in_4bit=True,
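The quantization block that starts in this hunk (and continues as context at the top of the next one) stays commented out in the commit. For reference, a hedged sketch of the 4-bit BitsAndBytesConfig those comments point at; only load_in_4bit and the "nf4" quant type appear in the commit, so the compute dtype below is an assumed, commonly used choice, and the block requires the bitsandbytes package:

import torch
from transformers import BitsAndBytesConfig

# Illustrative 4-bit NF4 config; it would be passed as quantization_config= to
# AutoModelForCausalLM.from_pretrained if the transformers loading path were
# ever re-enabled. bnb_4bit_compute_dtype is an assumption, not from the commit.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

The commented-out BitsAndBytesConfig(load_in_8bit=True) line in the same block is the 8-bit alternative.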
@@ -32,19 +35,21 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 #    bnb_4bit_quant_type= "nf4")
 # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
-model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
-tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-    attn_implementation="flash_attention_2",
-    # quantization_config=quantization_config,
-)
-model.eval()
-
-
-@spaces.GPU(duration=100)
+# tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+# model = AutoModelForCausalLM.from_pretrained(
+#     model_id,
+#     device_map="auto",
+#     torch_dtype=torch.bfloat16,
+#     attn_implementation="flash_attention_2",
+#     # quantization_config=quantization_config,
+# )
+# model.eval()
+
+subprocess.run(f'huggingface-cli download {model_id} --local_dir ./local_model', shell=True)
+model = LLM(model='./local_model', tensor_parallel_size=torch.cuda.device_count())
+print("init model done.")
+
+@spaces.GPU(duration=60)
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
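This hunk replaces the in-process transformers load with a two-step setup: download the checkpoint into ./local_model, then point a vLLM LLM engine at that directory sharded across every visible GPU, and shorten the spaces.GPU duration from 100 to 60 seconds. A sketch of the same setup through huggingface_hub's snapshot_download instead of a shell call to huggingface-cli (whose download flag is spelled --local-dir); model_id and the ./local_model path come from the commit, the rest is an assumption:

import torch
from huggingface_hub import snapshot_download
from vllm import LLM

model_id = "CardinalOperations/ORLM-LLaMA-3-8B"

# Fetch the checkpoint once; snapshot_download returns the local path and
# skips files that are already present on later runs.
local_path = snapshot_download(repo_id=model_id, local_dir="./local_model")

# Build the vLLM engine from the downloaded weights, sharding across all
# visible GPUs, mirroring tensor_parallel_size in the commit.
model = LLM(model=local_path, tensor_parallel_size=torch.cuda.device_count())
print("init model done.")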
@@ -57,33 +62,33 @@ def generate(
     if chat_history != []:
         return "Sorry, I am an instruction-tuned model and currently do not support chatting. Please try clearing the chat history or refreshing the page to ask a new question."
 
-    tokenized_example = tokenizer(message, return_tensors='pt', max_length=MAX_INPUT_TOKEN_LENGTH, truncation=True)
-    input_ids = tokenized_example.input_ids
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=False if temperature == 0.0 else True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-        eos_token_id=[tok.eos_token_id],
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-
-    # outputs.append("\n\nI have now attempted to solve the optimization modeling task! Please try executing the code in your environment, making sure it is equipped with `coptpy`.")
-    # yield "".join(outputs)
+    # tokenized_example = tokenizer(message, return_tensors='pt', max_length=MAX_INPUT_TOKEN_LENGTH, truncation=True)
+    # input_ids = tokenized_example.input_ids
+    # input_ids = input_ids.to(model.device)
+
+    # streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
+    # generate_kwargs = dict(
+    #     {"input_ids": input_ids},
+    #     streamer=streamer,
+    #     max_new_tokens=max_new_tokens,
+    #     do_sample=False if temperature == 0.0 else True,
+    #     top_p=top_p,
+    #     top_k=top_k,
+    #     temperature=temperature,
+    #     num_beams=1,
+    #     repetition_penalty=repetition_penalty,
+    #     eos_token_id=[tok.eos_token_id],
+    # )
+
+    prompts = [message]
+    stop_tokens = ["</s>"]
+    if temperature == 0.0:
+        sampling_params = SamplingParams(n=topk, temperature=0, top_p=1, repetition_penalty=repetition_penalty, max_tokens=max_new_tokens, stop=stop_tokens)
+    else:
+        sampling_params = SamplingParams(n=topk, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, max_tokens=max_new_tokens, stop=stop_tokens)
+    generations = model.generate(prompts, sampling_params)
+    outputs = [g.outputs[0].text for g in generations]
+    return outputs[0]
 
 
 chat_interface = gr.ChatInterface(
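The new generate body swaps the threaded TextIteratorStreamer loop for a single blocking vLLM call built around SamplingParams, so the full completion is returned at once rather than streamed token by token. A minimal sketch of that pattern as a standalone helper, reusing the engine from the sketch above; it assumes the slider value arrives as top_k (the committed call passes n=topk, while in vLLM's SamplingParams n is the number of completions returned per prompt and top_k is the top-k sampling cutoff) and keeps the "</s>" stop string from the commit:

from vllm import LLM, SamplingParams


def run_orlm(llm: LLM, message: str, *, temperature: float, top_p: float,
             top_k: int, repetition_penalty: float, max_new_tokens: int) -> str:
    """Return one completion for a single prompt (illustrative wrapper, not the app's exact API)."""
    if temperature == 0.0:
        # Greedy decoding: temperature 0 disables sampling in vLLM.
        params = SamplingParams(n=1, temperature=0.0, top_p=1.0,
                                repetition_penalty=repetition_penalty,
                                max_tokens=max_new_tokens, stop=["</s>"])
    else:
        params = SamplingParams(n=1, temperature=temperature, top_p=top_p, top_k=top_k,
                                repetition_penalty=repetition_penalty,
                                max_tokens=max_new_tokens, stop=["</s>"])
    generations = llm.generate([message], params)
    # llm.generate returns one RequestOutput per prompt; each carries n CompletionOutputs.
    return generations[0].outputs[0].text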
 