# File size: 3,731 Bytes
# addaa24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# File to store model links
MODEL_FILE = "model_links.txt"


def load_model_links():
    """Return the model identifiers listed in ``MODEL_FILE``.

    The file is read as UTF-8 text, one identifier per line. Surrounding
    whitespace is stripped from every line and blank lines are skipped.

    Returns:
        list[str]: The non-empty, stripped lines in file order. An empty
        list is returned when ``MODEL_FILE`` does not exist, so a caller
        building a UI at import time starts with no choices instead of
        crashing with ``FileNotFoundError``.
    """
    try:
        with open(MODEL_FILE, encoding="utf-8") as links_file:
            # Iterate the file object directly; no need to materialize
            # every raw line with readlines() first.
            stripped_lines = [line.strip() for line in links_file]
    except FileNotFoundError:
        return []
    return [name for name in stripped_lines if name]

class ModelManager:
    """Keep at most one causal language model and its tokenizer loaded.

    ``current_model``, ``current_tokenizer`` and ``current_model_name`` are
    all ``None`` until ``load_model`` succeeds.
    """

    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None

    def load_model(self, model_name):
        """Load ``model_name``, releasing any previously loaded model first.

        Args:
            model_name: Identifier handed to ``from_pretrained``. A falsy
                value (``None`` or ``""``, e.g. nothing selected in a
                dropdown) is rejected without touching the current model.

        Returns:
            str: A status message for display.

        Raises:
            Whatever ``from_pretrained`` raises when loading fails. In that
            case the manager is left in the "no model loaded" state.
        """
        if not model_name:
            return "Please select a model first."

        # Rebind to None instead of `del`: deleting the attributes would
        # leave the instance without them if loading below fails, and
        # generate_response() would then die with AttributeError.
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        torch.cuda.empty_cache()

        # Load into locals first so the manager never holds a tokenizer
        # without its matching model.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        self.current_tokenizer = tokenizer
        self.current_model = model
        self.current_model_name = model_name
        return f"Loaded model: {model_name}"

    def generate_response(self, system_message, user_message):
        """Generate the assistant's reply to a system + user message pair.

        Args:
            system_message: Text placed at the start of the prompt.
            user_message: Text placed after the ``User:`` marker.

        Returns:
            str: The decoded, whitespace-stripped continuation produced by
            the model (prompt excluded), or a hint to load a model when
            none is loaded.
        """
        if self.current_model is None or self.current_tokenizer is None:
            return "Please select and load a model first."

        # Combine system and user messages
        prompt = f"{system_message}\n\nUser: {user_message}\n\nAssistant:"

        # A single prompt needs no padding; requesting it fails for
        # tokenizers that define no pad token.
        inputs = self.current_tokenizer(prompt, return_tensors="pt")
        prompt_length = inputs["input_ids"].shape[-1]

        outputs = self.current_model.generate(
            **inputs,  # passes input_ids together with the attention mask
            # Budget applies to generated tokens only, so a long prompt
            # cannot use up the whole allowance.
            max_new_tokens=200,
            num_return_sequences=1,
            # temperature only takes effect when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.current_tokenizer.eos_token_id,
        )

        # Decode only the tokens after the prompt rather than splitting the
        # full text on "Assistant:", which picks the wrong segment whenever
        # the messages or the output contain that marker.
        generated_tokens = outputs[0][prompt_length:]
        response = self.current_tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        )
        return response.strip()

# One shared manager instance backs both button callbacks below.
model_manager = ModelManager()

# Two-column Gradio page: controls in the first column, results in the second.
with gr.Blocks() as demo:
    gr.Markdown(value="# Chat Interface with Model Selection")

    with gr.Row():
        # First column: model picker and message inputs.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="Select Model",
                info="Choose a model from the list",
                choices=load_model_links(),
            )
            load_button = gr.Button(value="Load Selected Model")
            system_msg = gr.Textbox(
                lines=3,
                label="System Message",
                placeholder="Enter system message here...",
            )
            user_msg = gr.Textbox(
                lines=3,
                label="User Message",
                placeholder="Enter your message here...",
            )
            submit_button = gr.Button(value="Generate Response")

        # Second column: load status and the generated reply.
        with gr.Column(scale=1):
            model_status = gr.Textbox(label="Model Status")
            chat_output = gr.Textbox(
                interactive=False,
                lines=10,
                label="Assistant Response",
            )

    # Wire the buttons to the manager's methods.
    load_button.click(
        fn=model_manager.load_model,
        inputs=model_dropdown,
        outputs=model_status,
    )
    submit_button.click(
        fn=model_manager.generate_response,
        inputs=[system_msg, user_msg],
        outputs=chat_output,
    )

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()