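"""Gradio demo: pick a causal language model from a configurable list, load it,
and generate chat-style responses from a system + user prompt."""
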
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
# File to store model links
MODEL_FILE = "model_links.txt"
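# Expected format: one Hugging Face model id per line, e.g. "facebook/opt-125m"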

def load_model_links():
    """Load model links from file, creating a default file if it is missing."""
    if not os.path.exists(MODEL_FILE):
        # Create default file with some example models
        with open(MODEL_FILE, "w") as f:
            f.write("facebook/opt-125m\n")
            f.write("facebook/opt-350m\n")
    with open(MODEL_FILE, "r") as f:
        return [line.strip() for line in f if line.strip()]

class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None

    def load_model(self, model_name):
        """Load a model and free the previous model's memory."""
        if self.current_model is not None:
            del self.current_model
            del self.current_tokenizer
            torch.cuda.empty_cache()
        self.current_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.current_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.current_model_name = model_name
        return f"Loaded model: {model_name}"
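
    # Note: from_pretrained keeps the model on CPU by default; on a GPU machine
    # one could move the model (and the tokenized inputs in generate_response)
    # to the same device, e.g. "cuda".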

    def generate_response(self, system_message, user_message):
        """Generate a response from the currently loaded model."""
        if self.current_model is None:
            return "Please select and load a model first."

        # Combine system and user messages into a single prompt
        prompt = f"{system_message}\n\nUser: {user_message}\n\nAssistant:"

        # Tokenize and generate. The attention mask is passed explicitly, and
        # do_sample=True is set so that temperature actually takes effect;
        # max_new_tokens bounds the response length independently of the prompt.
        inputs = self.current_tokenizer(prompt, return_tensors="pt")
        outputs = self.current_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=200,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.current_tokenizer.eos_token_id
        )
        response = self.current_tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the assistant's response
        response = response.split("Assistant:")[-1].strip()
        return response
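
    # Note: for instruction-tuned models, tokenizer.apply_chat_template could
    # be used to build the prompt instead of this hand-rolled format.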

# Initialize model manager
model_manager = ModelManager()

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Chat Interface with Model Selection")

    with gr.Row():
        with gr.Column(scale=1):
            # Input components
            model_dropdown = gr.Dropdown(
                choices=load_model_links(),
                label="Select Model",
                info="Choose a model from the list"
            )
            load_button = gr.Button("Load Selected Model")
            system_msg = gr.Textbox(
                label="System Message",
                placeholder="Enter system message here...",
                lines=3
            )
            user_msg = gr.Textbox(
                label="User Message",
                placeholder="Enter your message here...",
                lines=3
            )
            submit_button = gr.Button("Generate Response")

        with gr.Column(scale=1):
            # Output components
            model_status = gr.Textbox(label="Model Status")
            chat_output = gr.Textbox(
                label="Assistant Response",
                lines=10,
                interactive=False
            )

    # Event handlers (registered inside the Blocks context)
    load_button.click(
        fn=model_manager.load_model,
        inputs=[model_dropdown],
        outputs=[model_status]
    )
    submit_button.click(
        fn=model_manager.generate_response,
        inputs=[system_msg, user_msg],
        outputs=[chat_output]
    )

# Launch the app
if __name__ == "__main__":
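    # When running locally, demo.launch(share=True) would also create a
    # temporary public link.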
    demo.launch()