import base64
import json
import os
import threading
from io import BytesIO

import gradio as gr
from groq import Groq
from mistralai import Mistral  # Pixtral-12B integration

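# A minimal fail-fast check (assumption: both services are mandatory for this
# app); a clear startup error beats a KeyError in the middle of a request.
for _key in ("GROQ_API_KEY", "MISTRAL_API_KEY"):
    if _key not in os.environ:
        raise RuntimeError(f"Set the {_key} environment variable before launching.")
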
# Initialize Groq client
client = Groq(api_key=os.environ["GROQ_API_KEY"])

# Initialize Mistral AI client (Pixtral-12B-2409 for VQA)
mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
pixtral_model = "pixtral-12b-2409"

# Load Text-to-Image Models
model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
model2 = gr.load("models/Purz/face-projection")

# Cooperative stop flag for image generation; checked between variations.
# Note: no UI control currently sets this event.
stop_event = threading.Event()


# Convert a PIL image to a Base64 string (JPEG by default)
def pil_to_base64(pil_image, image_format='jpeg'):
    # JPEG cannot encode an alpha channel, so normalize to RGB first
    if image_format.lower() == 'jpeg' and pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')
    buffered = BytesIO()
    pil_image.save(buffered, format=image_format)
    base64_string = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return base64_string, image_format


# Generate tutor output (lesson, question, feedback) as a JSON string
def generate_tutor_output(subject, difficulty, student_input):
    prompt = f"""
    You are an expert tutor in {subject} at the {difficulty} level.
    The student has provided the following input: "{student_input}"

    Please generate:
    1. A brief, engaging lesson on the topic (2-3 paragraphs)
    2. A thought-provoking question to check understanding
    3. Constructive feedback on the student's input

    Format your response as a JSON object with keys: "lesson", "question", "feedback"
    """

    completion = client.chat.completions.create(
        messages=[{
            "role": "system",
            "content": f"You are the world's best AI tutor, renowned for explaining complex concepts with clarity and examples. Your expertise in {subject} is unparalleled, and you are adept at tailoring your teaching to {difficulty}-level students."
        }, {
            "role": "user",
            "content": prompt,
        }],
        model="mixtral-8x7b-32768",
        # Groq's OpenAI-compatible JSON mode keeps the reply machine-parseable
        response_format={"type": "json_object"},
        max_tokens=1000,
    )

    return completion.choices[0].message.content
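

# Parse the tutor's JSON reply into the three UI fields. This helper replaces
# the original `eval(...)` call in the click handler, which would execute
# arbitrary model output and also chokes on JSON literals like `true`/`null`.
# A minimal sketch: on a malformed reply, the raw text is shown in the lesson
# pane rather than raising.
def parse_tutor_response(subject, difficulty, student_input):
    raw = generate_tutor_output(subject, difficulty, student_input)
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return raw, "", ""
    return (
        parsed.get("lesson", ""),
        parsed.get("question", ""),
        parsed.get("feedback", ""),
    )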


# Generate three prompt variations with the selected text-to-image model
def generate_images(text, selected_model):
    stop_event.clear()

    models = {
        "Model 1 (Turbo Realism)": model1,
        "Model 2 (Face Projection)": model2,
    }
    model = models.get(selected_model)
    if model is None:
        # Raising gr.Error surfaces the message in the UI; returning plain
        # strings to gr.Image outputs would fail to render
        raise gr.Error("Invalid model selection.")

    results = []
    for i in range(3):
        if stop_event.is_set():
            raise gr.Error("Image generation stopped by user.")

        modified_text = f"{text} variation {i+1}"
        results.append(model(modified_text))

    return results


# Visual Question Answering via Pixtral-12B
def answer_question(text, image, temperature=0.0, max_tokens=1024):
    # Guard against missing inputs; pil_to_base64 would crash on None
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not text or not text.strip():
        raise gr.Error("Please enter a question about the image.")

    base64_string, file_format = pil_to_base64(image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": f"data:image/{file_format};base64,{base64_string}"}
            ]
        }
    ]

    chat_response = mistral_client.chat.complete(
        model=pixtral_model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens
    )

    return chat_response.choices[0].message.content


# Clear all fields
def clear_all():
    return "", None, ""


# Set up the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎓 AI Tutor & Visual Learning Assistant")

    # Section 1: Text-based Learning
    with gr.Row():
        with gr.Column(scale=2):
            subject = gr.Dropdown(["Math", "Science", "History", "Literature", "Code", "AI"], label="Subject")
            difficulty = gr.Radio(["Beginner", "Intermediate", "Advanced"], label="Difficulty Level")
            student_input = gr.Textbox(placeholder="Type your query here...", label="Your Input")
            submit_button_text = gr.Button("Generate Lesson & Question", variant="primary")
        
        with gr.Column(scale=3):
            lesson_output = gr.Markdown(label="Lesson")
            question_output = gr.Markdown(label="Comprehension Question")
            feedback_output = gr.Markdown(label="Feedback")
    
    # Section 2: Image Generation
    with gr.Row():
        with gr.Column(scale=2):
            model_selector = gr.Radio(
                ["Model 1 (Turbo Realism)", "Model 2 (Face Projection)"],
                label="Select Image Generation Model",
                value="Model 1 (Turbo Realism)"
            )
            submit_button_visual = gr.Button("Generate Visuals", variant="primary")
        
        with gr.Column(scale=3):
            output1 = gr.Image(label="Generated Image 1")
            output2 = gr.Image(label="Generated Image 2")
            output3 = gr.Image(label="Generated Image 3")
    
    # Section 3: Visual Question Answering (Pixtral-12B)
    gr.Markdown("## 🖼️ Visual Question Answering (Pixtral-12B)")
    with gr.Row():
        with gr.Column(scale=2):
            question = gr.Textbox(placeholder="Ask about the image...", lines=2, label="Your Question")
            image = gr.Image(type="pil", label="Image")
            with gr.Row():
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=2048, value=1024, step=128)
        
        with gr.Column(scale=3):
            output_text = gr.Textbox(lines=10, label="Pixtral 12B Response")
    
    with gr.Row():
        clear_btn = gr.Button("Clear", variant="secondary")
        submit_btn_vqa = gr.Button("Submit", variant="primary")

    # Generate text-based output (parsed safely instead of via eval)
    submit_button_text.click(
        fn=parse_tutor_response,
        inputs=[subject, difficulty, student_input],
        outputs=[lesson_output, question_output, feedback_output]
    )
    
    # Generate Visual Output
    submit_button_visual.click(
        fn=generate_images,
        inputs=[student_input, model_selector],
        outputs=[output1, output2, output3]
    )

    # VQA Processing
    submit_btn_vqa.click(
        fn=answer_question,
        inputs=[question, image, temperature, max_tokens],
        outputs=[output_text]
    )

    # Clear VQA Inputs
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[question, image, output_text]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
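
# Usage sketch (assumes this file is saved as app.py and both API keys are
# exported in the environment):
#   export GROQ_API_KEY=... MISTRAL_API_KEY=...
#   python app.py
# The interface is then served at http://0.0.0.0:7860.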