import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load Mistral-7B Model
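# Note: device_map="auto" requires the `accelerate` package, and loading the
# 7B model in float16 needs roughly 14-15 GB of GPU memory.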
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, 
    torch_dtype=torch.float16,  
    device_map="auto"
)

# Chatbot function: history is kept as a list of (user, assistant) pairs,
# which is the format gr.Chatbot displays
def chat_with_mistral(prompt, history):
    history = history or []

    # Rebuild the running transcript from previous turns
    lines = []
    for user_msg, assistant_msg in history:
        lines.append(f"User: {user_msg}")
        lines.append(f"Assistant: {assistant_msg}")
    lines.append(f"User: {prompt}")
    input_text = "\n".join(lines) + "\nAssistant:"

    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=200)
    # The decoded output contains the full transcript; keep only the text
    # after the final "Assistant:" marker
    response = tokenizer.decode(output[0], skip_special_tokens=True).split("Assistant:")[-1].strip()

    history.append((prompt, response))
    # Clear the textbox, refresh the chat window, and persist the new history
    return "", history, history
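
# A minimal alternative sketch for the prompt-building step above: recent
# transformers versions ship a chat template for Mistral-Instruct models, so
# the manual "User:/Assistant:" transcript could be replaced with
# tokenizer.apply_chat_template. Kept commented out; it assumes the same
# `tokenizer` and `model` objects and is not wired into the app.
#
# messages = [{"role": "user", "content": prompt}]
# input_ids = tokenizer.apply_chat_template(
#     messages, add_generation_prompt=True, return_tensors="pt"
# ).to(model.device)
# output = model.generate(input_ids, max_new_tokens=200)
# response = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)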

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("### 🤖 Mistral-7B Chatbot on Hugging Face Spaces")
    
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Type your message here...")
    clear = gr.Button("Clear Chat")

    history = gr.State([])

    msg.submit(chat_with_mistral, inputs=[msg, history], outputs=[msg, chatbot, history])
    clear.click(lambda: ([], []), inputs=[], outputs=[chatbot, history])

# Launch Gradio app
demo.launch()