import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face repo for the Hymba-1.5B base model
repo_name = "nvidia/Hymba-1.5B-Base"

# Load the tokenizer and model (trust_remote_code is required for Hymba's custom model code)
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
# Move the model to the GPU and run it in bfloat16 (a CUDA-capable device is required)
model = model.cuda().to(torch.bfloat16)
# Define the chatbot function
def chat_with_hymba(prompt):
    # Tokenize the prompt and move the tensors to the GPU
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    # Sample a continuation; max_new_tokens bounds the reply length regardless of prompt size
    outputs = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=0.7, use_cache=True)
    # Decode only the newly generated tokens, skipping the echoed prompt
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return response
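
# Example usage (illustrative prompt; the sampled reply will vary between runs):
#   chat_with_hymba("Tell me a fun fact about hummingbirds.")
#   -> returns up to 64 newly generated tokens as a plain string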
# Create Gradio Interface
interface = gr.Interface(
    fn=chat_with_hymba,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs="text",
    title="Chat with Hymba",
    description="Interact with the Hymba-1.5B model in real-time!",
)
# Launch the interface
if __name__ == "__main__":
    interface.launch()