# Source: SameerArz — "Update app.py" (commit 0576dea, verified)
# Hugging Face file-viewer metadata: raw / history / blame — 6.55 kB
import base64
import json
import os
import threading
from io import BytesIO

import gradio as gr
from groq import Groq
from mistralai import Mistral  # Pixtral-12B integration
# ---- Module-level singletons: API clients, hosted models, cancellation flag ----
# Initialize Groq client
# NOTE(review): raises KeyError at import time if GROQ_API_KEY is unset — confirm
# that's the intended fail-fast behavior for this deployment.
client = Groq(api_key=os.environ["GROQ_API_KEY"])
# Initialize Mistral AI client (Pixtral-12B-2409 for VQA)
mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
# Model id passed to mistral_client.chat.complete() in answer_question().
pixtral_model = "pixtral-12b-2409"
# Load Text-to-Image Models
# gr.load() binds hosted Hugging Face inference endpoints as callables.
model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
model2 = gr.load("models/Purz/face-projection")
# Stop event for threading (image generation)
# Polled each iteration in generate_images(); nothing in the current UI sets it,
# but it allows a future "stop" control to cancel the 3-image loop.
stop_event = threading.Event()
# Convert PIL image to Base64
def pil_to_base64(pil_image, image_format='jpeg'):
    """Serialize a PIL image to a base64 string.

    Returns a ``(base64_string, image_format)`` pair so the caller can
    assemble a ``data:`` URI without re-stating the format.
    """
    raw_buffer = BytesIO()
    pil_image.save(raw_buffer, format=image_format)
    encoded = base64.b64encode(raw_buffer.getvalue()).decode('utf-8')
    return encoded, image_format
# Function to generate tutor output (lesson, question, feedback)
def generate_tutor_output(subject, difficulty, student_input):
    """Ask the Groq chat model for a lesson, a comprehension question, and
    feedback on ``student_input``.

    Returns the raw model text, which the prompt instructs to be a JSON
    object with keys "lesson", "question", "feedback" — the caller is
    responsible for parsing it.
    """
    lesson_request = f"""
You are an expert tutor in {subject} at the {difficulty} level.
The student has provided the following input: "{student_input}"
Please generate:
1. A brief, engaging lesson on the topic (2-3 paragraphs)
2. A thought-provoking question to check understanding
3. Constructive feedback on the student's input
Format your response as a JSON object with keys: "lesson", "question", "feedback"
"""
    system_message = f"You are the world's best AI tutor, renowned for explaining complex concepts with clarity and examples. Your expertise in {subject} is unparalleled, and you're adept at tailoring your teaching to {difficulty} level students."
    response = client.chat.completions.create(
        model="mixtral-8x7b-32768",
        max_tokens=1000,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": lesson_request},
        ],
    )
    return response.choices[0].message.content
# Function to generate images based on model selection
def generate_images(text, selected_model):
    """Produce three prompt variations of ``text`` with the chosen model.

    Resets the shared ``stop_event`` first, then checks it before each of
    the three generations so an external "stop" can abort mid-batch.
    Returns a 3-element list of images, or 3 identical error strings.
    """
    stop_event.clear()
    available_models = {
        "Model 1 (Turbo Realism)": model1,
        "Model 2 (Face Projection)": model2,
    }
    chosen = available_models.get(selected_model)
    if chosen is None:
        return ["Invalid model selection."] * 3
    generated = []
    for variation in range(1, 4):
        if stop_event.is_set():
            return ["Image generation stopped by user."] * 3
        generated.append(chosen(f"{text} variation {variation}"))
    return generated
# Function for Visual Question Answering (Pixtral-12B)
def answer_question(text, image, temperature=0.0, max_tokens=1024):
    """Visual Question Answering: ask Pixtral-12B about a PIL ``image``.

    The image is inlined as a base64 data URI alongside the question text
    in a single user message; returns the model's text reply.
    """
    encoded, fmt = pil_to_base64(image)
    data_uri = f"data:image/{fmt};base64,{encoded}"
    user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": text},
            {"type": "image_url", "image_url": data_uri},
        ],
    }
    chat_response = mistral_client.chat.complete(
        model=pixtral_model,
        messages=[user_message],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return chat_response.choices[0].message.content
# Clear all fields
def clear_all():
    """Reset the VQA widgets: empty question, no image, empty response."""
    cleared_question, cleared_image, cleared_output = "", None, ""
    return cleared_question, cleared_image, cleared_output
# Set up the Gradio interface
def _parse_tutor_response(raw_response):
    """Parse the tutor model's reply into (lesson, question, feedback).

    Replaces the previous ``eval()`` call on LLM output, which was a
    code-injection risk and also failed on valid JSON tokens such as
    ``true``/``false``/``null``. On any parse failure the raw text is
    shown in the lesson pane instead of crashing the handler.
    """
    try:
        parsed = json.loads(raw_response)
    except (json.JSONDecodeError, TypeError):
        return raw_response, "", "⚠️ Model response was not valid JSON."
    if not isinstance(parsed, dict):
        return raw_response, "", "⚠️ Model response was not a JSON object."
    # Missing keys degrade to empty strings rather than raising KeyError.
    return (
        parsed.get("lesson", ""),
        parsed.get("question", ""),
        parsed.get("feedback", ""),
    )

# Set up the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎓 AI Tutor & Visual Learning Assistant")

    # Section 1: Text-based Learning
    with gr.Row():
        with gr.Column(scale=2):
            subject = gr.Dropdown(["Math", "Science", "History", "Literature", "Code", "AI"], label="Subject")
            difficulty = gr.Radio(["Beginner", "Intermediate", "Advanced"], label="Difficulty Level")
            student_input = gr.Textbox(placeholder="Type your query here...", label="Your Input")
            submit_button_text = gr.Button("Generate Lesson & Question", variant="primary")
        with gr.Column(scale=3):
            lesson_output = gr.Markdown(label="Lesson")
            question_output = gr.Markdown(label="Comprehension Question")
            feedback_output = gr.Markdown(label="Feedback")

    # Section 2: Image Generation
    with gr.Row():
        with gr.Column(scale=2):
            model_selector = gr.Radio(
                ["Model 1 (Turbo Realism)", "Model 2 (Face Projection)"],
                label="Select Image Generation Model",
                value="Model 1 (Turbo Realism)"
            )
            submit_button_visual = gr.Button("Generate Visuals", variant="primary")
        with gr.Column(scale=3):
            output1 = gr.Image(label="Generated Image 1")
            output2 = gr.Image(label="Generated Image 2")
            output3 = gr.Image(label="Generated Image 3")

    # Section 3: Visual Question Answering (Pixtral-12B)
    gr.Markdown("## 🖼️ Visual Question Answering (Pixtral-12B)")
    with gr.Row():
        with gr.Column(scale=2):
            question = gr.Textbox(placeholder="Ask about the image...", lines=2)
            image = gr.Image(type="pil")
            with gr.Row():
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=2048, value=1024, step=128)
        with gr.Column(scale=3):
            output_text = gr.Textbox(lines=10, label="Pixtral 12B Response")
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                submit_btn_vqa = gr.Button("Submit", variant="primary")

    # Generate Text-based Output: safe JSON parsing instead of eval().
    submit_button_text.click(
        fn=lambda subject, difficulty, student_input: _parse_tutor_response(
            generate_tutor_output(subject, difficulty, student_input)
        ),
        inputs=[subject, difficulty, student_input],
        outputs=[lesson_output, question_output, feedback_output]
    )

    # Generate Visual Output (reuses the tutor text box as the prompt).
    submit_button_visual.click(
        fn=generate_images,
        inputs=[student_input, model_selector],
        outputs=[output1, output2, output3]
    )

    # VQA Processing
    submit_btn_vqa.click(
        fn=answer_question,
        inputs=[question, image, temperature, max_tokens],
        outputs=[output_text]
    )

    # Clear VQA Inputs
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[question, image, output_text]
    )

if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)