tangzhy commited on
Commit
b2b7f7a
·
verified ·
1 Parent(s): d3128ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -47
app.py CHANGED
@@ -15,18 +15,14 @@ from transformers import (
15
  import subprocess
16
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
17
 
18
- from vllm import LLM, SamplingParams
19
-
20
  DESCRIPTION = """\
21
  # ORLM LLaMA-3-8B
22
-
23
  Hello! I'm ORLM-LLaMA-3-8B, here to automate your optimization modeling tasks! Check our [repo](https://github.com/Cardinal-Operations/ORLM) and [paper](https://arxiv.org/abs/2405.17743)!
24
  """
25
 
26
  MAX_MAX_NEW_TOKENS = 4096
27
  DEFAULT_MAX_NEW_TOKENS = 4096
28
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
29
- model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
30
 
31
  # quantization_config = BitsAndBytesConfig(
32
  # load_in_4bit=True,
@@ -35,21 +31,19 @@ model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
35
  # bnb_4bit_quant_type= "nf4")
36
  # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
37
 
38
- # tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
39
- # model = AutoModelForCausalLM.from_pretrained(
40
- # model_id,
41
- # device_map="auto",
42
- # torch_dtype=torch.bfloat16,
43
- # attn_implementation="flash_attention_2",
44
- # # quantization_config=quantization_config,
45
- # )
46
- # model.eval()
47
-
48
- subprocess.run(f'huggingface-cli download {model_id} --local-dir ./local_model', shell=True)
49
- model = LLM(model='./local_model', tensor_parallel_size=1)
50
- print("init model done.")
51
-
52
- @spaces.GPU(duration=60)
53
  def generate(
54
  message: str,
55
  chat_history: list[tuple[str, str]],
@@ -62,33 +56,33 @@ def generate(
62
  if chat_history != []:
63
  return "Sorry, I am an instruction-tuned model and currently do not support chatting. Please try clearing the chat history or refreshing the page to ask a new question."
64
 
65
- # tokenized_example = tokenizer(message, return_tensors='pt', max_length=MAX_INPUT_TOKEN_LENGTH, truncation=True)
66
- # input_ids = tokenized_example.input_ids
67
- # input_ids = input_ids.to(model.device)
68
-
69
- # streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
70
- # generate_kwargs = dict(
71
- # {"input_ids": input_ids},
72
- # streamer=streamer,
73
- # max_new_tokens=max_new_tokens,
74
- # do_sample=False if temperature == 0.0 else True,
75
- # top_p=top_p,
76
- # top_k=top_k,
77
- # temperature=temperature,
78
- # num_beams=1,
79
- # repetition_penalty=repetition_penalty,
80
- # eos_token_id=[tok.eos_token_id],
81
- # )
82
-
83
- prompts = [message]
84
- stop_tokens = ["</s>"]
85
- if temperature == 0.0:
86
- sampling_params = SamplingParams(n=topk, temperature=0, top_p=1, repetition_penalty=repetition_penalty, max_tokens=max_new_tokens, stop=stop_tokens)
87
- else:
88
- sampling_params = SamplingParams(n=topk, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, max_tokens=max_new_tokens, stop=stop_tokens)
89
- generations = model.generate(prompts, sampling_params)
90
- outputs = [g.outputs[0].text for g in generations]
91
- return outputs[0]
92
 
93
 
94
  chat_interface = gr.ChatInterface(
@@ -144,4 +138,4 @@ with gr.Blocks(css="style.css", fill_height=True) as demo:
144
  chat_interface.render()
145
 
146
  if __name__ == "__main__":
147
- demo.queue(max_size=20).launch()
 
15
  import subprocess
16
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
17
 
 
 
18
  DESCRIPTION = """\
19
  # ORLM LLaMA-3-8B
 
20
  Hello! I'm ORLM-LLaMA-3-8B, here to automate your optimization modeling tasks! Check our [repo](https://github.com/Cardinal-Operations/ORLM) and [paper](https://arxiv.org/abs/2405.17743)!
21
  """
22
 
23
  MAX_MAX_NEW_TOKENS = 4096
24
  DEFAULT_MAX_NEW_TOKENS = 4096
25
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
26
 
27
  # quantization_config = BitsAndBytesConfig(
28
  # load_in_4bit=True,
 
31
  # bnb_4bit_quant_type= "nf4")
32
  # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
33
 
34
+ model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
35
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ model_id,
38
+ device_map="auto",
39
+ torch_dtype=torch.bfloat16,
40
+ attn_implementation="flash_attention_2",
41
+ # quantization_config=quantization_config,
42
+ )
43
+ model.eval()
44
+
45
+
46
+ @spaces.GPU(duration=100)
 
 
47
  def generate(
48
  message: str,
49
  chat_history: list[tuple[str, str]],
 
56
  if chat_history != []:
57
  return "Sorry, I am an instruction-tuned model and currently do not support chatting. Please try clearing the chat history or refreshing the page to ask a new question."
58
 
59
+ tokenized_example = tokenizer(message, return_tensors='pt', max_length=MAX_INPUT_TOKEN_LENGTH, truncation=True)
60
+ input_ids = tokenized_example.input_ids
61
+ input_ids = input_ids.to(model.device)
62
+
63
+ streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
64
+ generate_kwargs = dict(
65
+ {"input_ids": input_ids},
66
+ streamer=streamer,
67
+ max_new_tokens=max_new_tokens,
68
+ do_sample=False if temperature == 0.0 else True,
69
+ top_p=top_p,
70
+ top_k=top_k,
71
+ temperature=temperature,
72
+ num_beams=1,
73
+ repetition_penalty=repetition_penalty,
74
+ eos_token_id=[tok.eos_token_id],
75
+ )
76
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
77
+ t.start()
78
+
79
+ outputs = []
80
+ for text in streamer:
81
+ outputs.append(text)
82
+ yield "".join(outputs)
83
+
84
+ # outputs.append("\n\nI have now attempted to solve the optimization modeling task! Please try executing the code in your environment, making sure it is equipped with `coptpy`.")
85
+ # yield "".join(outputs)
86
 
87
 
88
  chat_interface = gr.ChatInterface(
 
138
  chat_interface.render()
139
 
140
  if __name__ == "__main__":
141
+ demo.queue(max_size=20).launch()