alfredplpl committed
Commit 4fffa9e · verified · 1 Parent(s): 44c960c

Update app.py

Files changed (1):
  1. app.py +23 -9
app.py CHANGED
@@ -45,7 +45,7 @@ h1 {
 
 # Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0")
-model = AutoModelForCausalLM.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0", device_map="auto", torch_dtype=torch.bfloat16)
+model = AutoModelForCausalLM.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0", device_map="cuda", torch_dtype=torch.bfloat16)
 model=model.eval()
 
 @spaces.GPU()
@@ -70,16 +70,30 @@ def chat_llm_jp_v2(message: str,
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    tokenized_input = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=True, return_tensors="pt").to(model.device)
-    output = model.generate(
-        tokenized_input,
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
-        top_p=0.95,
         temperature=temperature,
-        repetition_penalty=1.05,
-    )[0]
-    return tokenizer.decode(output)
+        top_p=0.95,
+        repetition_penalty=1.1,
+        eos_token_id=terminators,
+    )
+    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+    if temperature == 0:
+        generate_kwargs['do_sample'] = False
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        print(outputs)
+        yield "".join(outputs)
 
 
 # Gradio block
@@ -97,7 +111,7 @@ with gr.Blocks(fill_height=True, css=css) as demo:
     additional_inputs=[
         gr.Slider(minimum=0.1,
                   maximum=1,
-                  step=0.1,
+                  step=0.0,
                   value=0.7,
                   label="Temperature",
                   render=False),
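
A note on the streaming path this commit introduces: the new code references input_ids and terminators, and no import changes are visible in these hunks, so Thread, TextIteratorStreamer, and those two names must be defined elsewhere in app.py. The sketch below is a minimal, self-contained version of the same pattern, with assumed definitions: input_ids built with apply_chat_template (as in the removed tokenized_input line) and terminators set to the tokenizer's EOS token id. It illustrates the technique, not the exact file contents.

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="cuda", torch_dtype=torch.bfloat16
).eval()


def stream_chat(message: str, temperature: float = 0.7, max_new_tokens: int = 512):
    conversation = [{"role": "user", "content": message}]

    # Assumed: the prompt is chat-templated exactly as the removed line did it.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    # Assumed: generation stops at the model's end-of-sequence token.
    terminators = [tokenizer.eos_token_id]

    # skip_prompt=True drops the echoed prompt; skip_special_tokens drops EOS markers.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.95,
        repetition_penalty=1.1,
        eos_token_id=terminators,
    )
    # Sampling with temperature 0 raises an error in transformers, so fall back to greedy.
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    # generate() blocks until done, so it runs in a worker thread while this
    # generator consumes the streamer and yields the partial response so far.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

Yielding the growing "".join(outputs) string on every chunk is what lets gr.ChatInterface render the reply incrementally; each yield replaces the message shown in the UI rather than appending to it.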