reshinthadith committed
Commit f541eb3
1 Parent(s): 2ddd665

Add cpu inference option for testing

Files changed (1): app.py (+17 -7)
app.py CHANGED
@@ -4,10 +4,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria
 import time
 import numpy as np
 from torch.nn import functional as F
-
-
-m = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16).cuda()
-tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b")
+import os
+token_key = os.environ.get("HUGGING_FACE_HUB_TOKEN")
+
+if torch.cuda.is_available():
+    m = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-7b", use_auth_token=token_key, torch_dtype=torch.float16).cuda()
+    tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b", use_auth_token=token_key)
+else:
+    m = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-7b", use_auth_token=token_key, torch_dtype=torch.float16)
+    tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b", use_auth_token=token_key)
 generator = pipeline('text-generation', model=m, tokenizer=tok, device=0)
 
 
@@ -29,8 +34,12 @@ class StopOnTokens(StoppingCriteria):
 
 def contrastive_generate(text, bad_text):
     with torch.no_grad():
-        tokens = tok(text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
-        bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
+        if torch.cuda.is_available():
+            tokens = tok(text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
+            bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
+        else:
+            tokens = tok(text, return_tensors="pt")['input_ids'][:,:4096-1024]
+            bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'][:,:4096-1024]
         history = None
         bad_history = None
         curr_output = list()
@@ -83,12 +92,13 @@ def system_update(msg):
 
 
 with gr.Blocks() as demo:
+    gr.Markdown("### StableLM-tuned-Alpha-7B Chat")
     with gr.Row():
         with gr.Column():
             chatbot = gr.Chatbot([])
             clear = gr.Button("Clear")
         with gr.Column():
-            system_msg = gr.Textbox(start_message, label="System Message", interactive=True)
+            system_msg = start_message  # gr.Textbox(start_message, label="System Message", interactive=True)
             msg = gr.Textbox(label="Chat Message")
 
     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
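A note on the loading split in the first hunk: the pipeline is still pinned to GPU 0 via `device=0`, which would fail on a CPU-only host, and the CPU branch keeps `torch_dtype=torch.float16` even though many half-precision ops have no CPU kernels. Below is a minimal sketch of a fully device-agnostic setup; the `device = -1` convention for CPU and the fp32-on-CPU choice are my own adjustments, not part of this commit.

```python
# Sketch, not the committed code: device-agnostic model/pipeline setup.
# Assumes HUGGING_FACE_HUB_TOKEN is set in the environment.
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

token_key = os.environ.get("HUGGING_FACE_HUB_TOKEN")
name = "stabilityai/stablelm-tuned-alpha-7b"

tok = AutoTokenizer.from_pretrained(name, use_auth_token=token_key)
if torch.cuda.is_available():
    # fp16 halves memory and is well supported on CUDA.
    m = AutoModelForCausalLM.from_pretrained(
        name, use_auth_token=token_key, torch_dtype=torch.float16).cuda()
    device = 0   # pipeline device index: first GPU
else:
    # fp32 on CPU: many half-precision kernels are missing on CPU.
    m = AutoModelForCausalLM.from_pretrained(name, use_auth_token=token_key)
    device = -1  # pipeline device index: CPU

generator = pipeline('text-generation', model=m, tokenizer=tok, device=device)
```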
 
 
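The per-tensor branching added to `contrastive_generate` in the second hunk can also be collapsed by resolving the device once and using `.to(device)`, which is effectively a no-op on CPU. A sketch under that assumption; `encode` is a hypothetical helper, and `4096-1024` is the commit's own truncation budget.

```python
# Sketch: resolve the device once instead of branching on every tensor.
# `tok` is the tokenizer loaded in the previous sketch.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def encode(text):
    # Hypothetical helper: tokenize, move to the resolved device, and keep
    # the same 4096-1024 token budget as the committed code.
    return tok(text, return_tensors="pt")['input_ids'].to(device)[:, :4096 - 1024]

def contrastive_generate(text, bad_text):
    with torch.no_grad():
        tokens = encode(text)
        bad_tokens = encode(bad_text)
        ...  # rest of the function unchanged
```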
 
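On the Gradio side, note that `"###StableLM-tuned-Alpha-7B Chat"` needs a space after `###` to render as a heading in CommonMark (fixed in the hunk above), and `system_msg` is now a plain string rather than a component, so it can no longer be wired into event handlers as a Gradio input. A sketch of the resulting layout; `start_message` is a placeholder for the system prompt defined earlier in app.py.

```python
# Sketch of the resulting Blocks layout.
import gradio as gr

start_message = "..."  # placeholder; app.py defines the real system prompt

with gr.Blocks() as demo:
    gr.Markdown("### StableLM-tuned-Alpha-7B Chat")  # "### " renders as an H3 heading
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot([])
            clear = gr.Button("Clear")
        with gr.Column():
            system_msg = start_message  # plain string, not a gr.Textbox input
            msg = gr.Textbox(label="Chat Message")
```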