import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# File to store model links
MODEL_FILE = "model_links.txt"
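# model_links.txt is expected to contain one Hugging Face model id per line, e.g.:
#   facebook/opt-125m
#   facebook/opt-350m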
def load_model_links():
    """Load model links from file."""
    if not os.path.exists(MODEL_FILE):
        # Create default file with some example models
        with open(MODEL_FILE, "w") as f:
            f.write("facebook/opt-125m\n")
            f.write("facebook/opt-350m\n")
    with open(MODEL_FILE, "r") as f:
        return [line.strip() for line in f.readlines() if line.strip()]
class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None

    def load_model(self, model_name):
        """Load model and free previous model's memory."""
        if self.current_model is not None:
            del self.current_model
            del self.current_tokenizer
            # Keep the attributes defined so generate_response still works if the load below fails
            self.current_model = None
            self.current_tokenizer = None
            torch.cuda.empty_cache()
        self.current_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.current_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.current_model_name = model_name
        return f"Loaded model: {model_name}"
    def generate_response(self, system_message, user_message):
        """Generate a response from the currently loaded model."""
        if self.current_model is None:
            return "Please select and load a model first."

        # Combine system and user messages into a single prompt
        prompt = f"{system_message}\n\nUser: {user_message}\n\nAssistant:"

        # Generate response
        inputs = self.current_tokenizer(prompt, return_tensors="pt", padding=True)
        outputs = self.current_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=200,
            num_return_sequences=1,
            do_sample=True,  # sampling must be enabled for temperature to take effect
            temperature=0.7,
            pad_token_id=self.current_tokenizer.eos_token_id
        )
        response = self.current_tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the assistant's response
        response = response.split("Assistant:")[-1].strip()
        return response
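
# Note (assumption about the models listed in model_links.txt): the plain "User:/Assistant:"
# prompt above suits base causal LMs such as facebook/opt-125m; instruction-tuned chat models
# usually expect their own chat format, e.g. built with
#   tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)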
# Initialize model manager
model_manager = ModelManager()

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Chat Interface with Model Selection")

    with gr.Row():
        with gr.Column(scale=1):
            # Input components
            model_dropdown = gr.Dropdown(
                choices=load_model_links(),
                label="Select Model",
                info="Choose a model from the list"
            )
            load_button = gr.Button("Load Selected Model")
            system_msg = gr.Textbox(
                label="System Message",
                placeholder="Enter system message here...",
                lines=3
            )
            user_msg = gr.Textbox(
                label="User Message",
                placeholder="Enter your message here...",
                lines=3
            )
            submit_button = gr.Button("Generate Response")

        with gr.Column(scale=1):
            # Output components
            model_status = gr.Textbox(label="Model Status")
            chat_output = gr.Textbox(
                label="Assistant Response",
                lines=10,
                interactive=False
            )

    # Event handlers
    load_button.click(
        fn=model_manager.load_model,
        inputs=[model_dropdown],
        outputs=[model_status]
    )
    submit_button.click(
        fn=model_manager.generate_response,
        inputs=[system_msg, user_msg],
        outputs=[chat_output]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
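
# To run locally (assuming gradio, torch, and transformers are installed):
#   python app.py
# then open the local URL that Gradio prints.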