CreitinGameplays committed (verified)
Commit d6af013 · Parent(s): 335ddf3

Update app.py

Files changed (1):
  1. app.py +28 -20
app.py CHANGED
@@ -3,7 +3,6 @@ from threading import Thread
 from typing import Iterator
 
 import gradio as gr
-import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
@@ -15,19 +14,34 @@ DESCRIPTION = """\
 # ConvAI 9b v2 Chat
 """
 
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
-
-if torch.cuda.is_available():
-    model_id = "CreitinGameplays/ConvAI-9b-v2"
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+# Load model with appropriate device configuration
+def load_model():
+    model_id = "CreitinGameplays/dumbbot"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # If using CPU, load in 32-bit to avoid potential issues with 16-bit operations
+    if device == "cpu":
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float32,
+            low_cpu_mem_usage=True
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16,
+            device_map="auto"
+        )
+
     tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
     tokenizer.use_default_system_prompt = False
+
+    return model, tokenizer, device
+
+model, tokenizer, device = load_model()
 
-system_prompt_text = "You are a helpful, respectful and honest AI assistant named ChatGPT. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t know the answer to a question, please don’t share false information."
+system_prompt_text = "You are Ricardinho."
 
-@spaces.GPU(duration=90)
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
@@ -49,11 +63,11 @@ def generate(
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
+    input_ids = input_ids.to(device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=5.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        {"input_ids": input_ids},
+        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
@@ -71,7 +85,6 @@ def generate(
         outputs.append(text)
         yield "".join(outputs)
 
-
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
@@ -122,13 +135,8 @@ chat_interface = gr.ChatInterface(
     ],
 )
 
-with gr.Blocks(css="style.css") as demo:
+with gr.Blocks() as demo:
     gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(
-        value="Duplicate Space for private use",
-        elem_id="duplicate-button",
-        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
-    )
     chat_interface.render()
 
 if __name__ == "__main__":
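
The two streamer tweaks above (timeout 5.0 → 10.0, and passing input_ids as a keyword argument instead of a positional dict) feed the usual Transformers streaming pattern that the rest of generate() follows: model.generate() blocks until decoding finishes, so it runs on a worker thread while the caller drains the streamer and yields partial text. A minimal sketch of that pattern, using a tiny placeholder model rather than this Space's model:

# Minimal sketch of the threaded streaming pattern; "sshleifer/tiny-gpt2"
# is a placeholder stand-in, not the model this commit uses.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Hello", return_tensors="pt").input_ids

# timeout bounds how long each iteration below waits for the next chunk;
# raising it from 5 s to 10 s gives slow CPU decoding more headroom.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a background thread while the main
# thread consumes decoded text as it arrives.
thread = Thread(target=model.generate, kwargs=dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=32,
    do_sample=True,
))
thread.start()

outputs = []
for text in streamer:
    outputs.append(text)
    print("".join(outputs))  # the app yields this string to Gradio instead
thread.join()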
 
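A note on the new CPU branch: float16 kernel coverage on CPU is patchy in PyTorch, so loading in float32 avoids op errors and numerical surprises, and low_cpu_mem_usage=True loads weights incrementally instead of materializing a second full copy in RAM. A hypothetical variant (not part of this commit) would be torch.bfloat16, which halves memory like float16 but keeps float32's exponent range and runs well on recent CPUs:

# Hypothetical alternative to the commit's CPU branch; a sketch only,
# not what this commit ships.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "CreitinGameplays/dumbbot",   # the model ID this commit switches to
    torch_dtype=torch.bfloat16,   # half the memory of float32, CPU-friendly
    low_cpu_mem_usage=True,       # stream weights in to cap peak RAM
)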