TejAndrewsACC committed
Commit b39d091 · verified · 1 Parent(s): d40f491

Update app.py

Files changed (1)
  1. app.py +44 -19
app.py CHANGED
@@ -1,33 +1,58 @@
 
  from transformers import AutoModelForCausalLM, AutoTokenizer, StopStringCriteria, StoppingCriteriaList
  import torch

  # Load the tokenizer and model
  repo_name = "nvidia/Hymba-1.5B-Instruct"
-
  tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
  model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
- model = model.cuda().to(torch.bfloat16)

- # Chat with Hymba
- prompt = input()

  messages = [
      {"role": "system", "content": "You are a helpful assistant."}
  ]
- messages.append({"role": "user", "content": prompt})
-
- # Apply chat template
- tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to('cuda')
- stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
- outputs = model.generate(
-     tokenized_chat,
-     max_new_tokens=256,
-     do_sample=False,
-     temperature=0.7,
-     use_cache=True,
-     stopping_criteria=stopping_criteria
  )
- input_length = tokenized_chat.shape[1]
- response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

- print(f"Model response: {response}")
 
+ import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer, StopStringCriteria, StoppingCriteriaList
  import torch

  # Load the tokenizer and model
  repo_name = "nvidia/Hymba-1.5B-Instruct"
  tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
  model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)

+ # Move the model to GPU with float16 precision for efficiency
+ model = model.to("cuda").to(torch.float16)

+ # Initialize the conversation history
  messages = [
      {"role": "system", "content": "You are a helpful assistant."}
  ]
+
+ # Define stopping criteria
+ stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings=["</s>"])])
+
+ # Chat function for Gradio interface
+ def chat_function(user_input):
+     # Add user message to the conversation history
+     messages.append({"role": "user", "content": user_input})
+
+     # Tokenize the conversation
+     tokenized_chat = tokenizer(messages, padding=True, truncation=True, return_tensors="pt").to("cuda")
+
+     # Generate a response
+     outputs = model.generate(
+         tokenized_chat["input_ids"],
+         max_new_tokens=256,
+         do_sample=False,
+         temperature=0.7,
+         use_cache=True,
+         stopping_criteria=stopping_criteria
+     )
+
+     # Decode the output response
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Add the assistant's response to the conversation history
+     messages.append({"role": "assistant", "content": response})
+
+     return response
+
+ # Set up Gradio interface with the chatbot template
+ iface = gr.Interface(
+     fn=chat_function,
+     inputs=gr.inputs.Textbox(label="Your message", placeholder="Enter your message here..."),
+     outputs=gr.outputs.Chatbot(),
+     live=True,
+     title="Hymba Chatbot",
+     description="Chat with the Hymba-1.5B-Instruct model!"
  )

+ # Launch the Gradio interface
+ iface.launch()
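
Note on the new chat_function: as committed it passes the list of role/content dicts straight to tokenizer(...), which a Hugging Face tokenizer cannot encode, and it decodes outputs[0] in full, so the returned string would include the prompt. A minimal sketch of an alternative, reusing apply_chat_template from the removed version together with the messages, tokenizer, model and stopping_criteria defined in app.py above (not part of this commit):

# Sketch only: chat_function built on apply_chat_template, assuming the
# module-level names from app.py above are in scope.
def chat_function(user_input):
    # Add the user turn to the running conversation
    messages.append({"role": "user", "content": user_input})

    # apply_chat_template renders the role/content dicts into a prompt;
    # calling tokenizer(...) on the list of dicts would fail instead.
    tokenized_chat = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to("cuda")

    outputs = model.generate(
        tokenized_chat,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding; temperature only matters when do_sample=True
        use_cache=True,
        stopping_criteria=stopping_criteria,
    )

    # Slice off the prompt tokens so only the assistant's reply is returned
    input_length = tokenized_chat.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

    # Keep the assistant turn in the history for the next call
    messages.append({"role": "assistant", "content": response})
    return response

Separately, the gr.inputs / gr.outputs namespaces are the legacy Gradio API and have been removed in Gradio 4.x; gr.Textbox and gr.Chatbot (or gr.ChatInterface) are the current equivalents, and a gr.Chatbot output expects the conversation history rather than a single string.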