Ozaii committed
Commit 30dbf5f · verified · 1 Parent(s): f7b393f

Update app.py

Files changed (1)
  1. app.py +11 -15
app.py CHANGED
```diff
@@ -1,7 +1,6 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftConfig, PeftModel
-from threading import Thread
 import gradio as gr
 import spaces
 
@@ -25,7 +24,6 @@ def load_model():
         BASE_MODEL,
         torch_dtype=torch.float16,
         device_map="auto",
-        load_in_4bit=True,
         trust_remote_code=True
     )
 
@@ -43,29 +41,27 @@ def load_model():
 def generate_response(prompt, max_new_tokens=128):
     model, tokenizer = load_model()
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = dict(
-        input_ids=inputs.input_ids,
+    outputs = model.generate(
+        **inputs,
         max_new_tokens=max_new_tokens,
         temperature=0.7,
         top_p=0.9,
         repetition_penalty=1.2,
-        streamer=streamer,
+        do_sample=True
     )
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-    return streamer
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 def chat_with_zephyr(message, history):
     conversation_history = history[-3:]  # Limit to last 3 exchanges
     full_prompt = "\n".join([f"Human: {h[0]}\nZephyr: {h[1]}" for h in conversation_history])
     full_prompt += f"\nHuman: {message}\nZephyr:"
 
-    streamer = generate_response(full_prompt)
-    response = ""
-    for new_text in streamer:
-        response += new_text
-        yield response
+    response = generate_response(full_prompt)
+
+    # Extract Zephyr's response
+    zephyr_response = response.split("Zephyr:")[-1].strip()
+
+    return zephyr_response
 
 css = """
 body {
```
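The removed `load_in_4bit=True` kwarg is worth a note: recent transformers releases expect 4-bit loading to go through a `BitsAndBytesConfig` passed as `quantization_config`, rather than a bare kwarg on `from_pretrained`. If 4-bit quantization were wanted back, it would look roughly like this sketch (an illustration, not part of this commit; `BASE_MODEL` is the constant defined earlier in app.py, and the `bitsandbytes` package must be installed):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch only: the quantization_config route that supersedes the
# bare load_in_4bit=True kwarg removed in this commit.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # match the fp16 dtype used above
)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,  # defined earlier in app.py
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
```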
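Because `chat_with_zephyr` now returns one complete string instead of yielding partial chunks, the Gradio side needs no streaming support. A minimal sketch of how the function would typically be wired up (the `ChatInterface` arguments and the `launch()` call are assumptions; the commit only shows where the `css` string begins):

```python
import gradio as gr

# Sketch: a plain, non-streaming chat UI around chat_with_zephyr.
# ChatInterface calls fn(message, history) and renders the returned string;
# a generator fn would instead stream, which this commit moves away from.
demo = gr.ChatInterface(
    fn=chat_with_zephyr,
    css=css,  # the style block whose definition starts below
    title="Chat with Zephyr",  # hypothetical title
)

demo.launch()
```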