import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

# Authenticate with the Hugging Face Hub. The token is read from the
# HF_TOKEN environment variable (e.g. a Space secret) rather than being
# passed as the literal string "HF_TOKEN".
login(token=os.environ["HF_TOKEN"])
# Load the model and tokenizer once at startup rather than on every call.
# Gemma is a decoder-only model, so AutoModelForCausalLM is the correct
# class for this checkpoint, not AutoModelForSeq2SeqLM.
# Note: device_map="auto" requires the accelerate package.
model_name = "google/gemma-1.1-7b-it"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )
except Exception as e:
    print(f"An error occurred while loading the model: {e}")
    raise


def predict(message, history):
    """Generate a reply to the user's message and update the history.

    Args:
        message (str): The user's input text.
        history (list): Previous (user, bot) pairs, carried between
            calls via the Gradio "state" component.

    Returns:
        tuple: The updated history for the Chatbot display and the state.
    """
    history = history or []
    # Process the user input with Gemma.
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=256)
    # generate() returns the prompt plus the new tokens; slice off the
    # prompt before decoding so only the reply is shown.
    response = tokenizer.decode(
        output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )
    history.append((message, response))
    return history, history
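
# Gemma's instruction-tuned checkpoints expect the <start_of_turn> chat
# format. Below is a minimal sketch of building such a prompt from the
# history, assuming the tokenizer ships a chat template (transformers
# >= 4.34); predict() could pass its result to model.generate() instead
# of tokenizing the raw string. The helper name is illustrative only.
def build_chat_inputs(message, history):
    """Format the (user, bot) history plus the new message with the chat template."""
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    # Returns a tensor of input ids ready for model.generate().
    return tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)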
# Create the Gradio interface. The "state" component carries the chat
# history between calls, and the "chatbot" component renders it.
interface = gr.Interface(
    fn=predict,
    inputs=["textbox", "state"],
    outputs=["chatbot", "state"],
)
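
# An alternative sketch, assuming a Gradio release that ships
# gr.ChatInterface (3.35 or newer): it wires the textbox, chatbot, and
# state together itself, given a fn(message, history) that returns only
# the reply string, so the explicit component lists above go away. The
# respond() wrapper is illustrative only:
#
#     def respond(message, history):
#         updated, _ = predict(message, list(history))
#         return updated[-1][1]
#
#     interface = gr.ChatInterface(fn=respond)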
# Launch the Gradio interface
interface.launch()