import gradio as gr
from groq import Groq
import os
import json
import threading
import base64
from io import BytesIO
from mistralai import Mistral  # Pixtral-12B integration

# Initialize Groq client
client = Groq(api_key=os.environ["GROQ_API_KEY"])

# Initialize Mistral AI client (Pixtral-12B-2409 for VQA)
mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
pixtral_model = "pixtral-12b-2409"

# Load Text-to-Image Models
model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
model2 = gr.load("models/Purz/face-projection")

# Stop event for threading (image generation)
stop_event = threading.Event()


# Convert a PIL image to a Base64 string
def pil_to_base64(pil_image, image_format="jpeg"):
    buffered = BytesIO()
    # JPEG cannot encode an alpha channel, so normalize to RGB before saving
    pil_image.convert("RGB").save(buffered, format=image_format)
    base64_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return base64_string, image_format


# Generate tutor output (lesson, question, feedback) via Groq
def generate_tutor_output(subject, difficulty, student_input):
    prompt = f"""
    You are an expert tutor in {subject} at the {difficulty} level.
    The student has provided the following input: "{student_input}"

    Please generate:
    1. A brief, engaging lesson on the topic (2-3 paragraphs)
    2. A thought-provoking question to check understanding
    3. Constructive feedback on the student's input

    Format your response as a JSON object with keys: "lesson", "question", "feedback"
    """

    completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": f"You are the world's best AI tutor, renowned for explaining complex concepts with clarity and examples. Your expertise in {subject} is unparalleled, and you're adept at tailoring your teaching to {difficulty} level students.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
        model="mixtral-8x7b-32768",
        max_tokens=1000,
    )

    return completion.choices[0].message.content


# Parse the tutor's JSON response and split it into the three output fields.
# json.loads is used instead of eval so arbitrary model output is never executed;
# if the response is not valid JSON, the raw text is shown in the lesson panel.
def generate_tutor_fields(subject, difficulty, student_input):
    raw = generate_tutor_output(subject, difficulty, student_input)
    try:
        parsed = json.loads(raw)
        return parsed["lesson"], parsed["question"], parsed["feedback"]
    except (json.JSONDecodeError, KeyError, TypeError):
        return raw, "", ""


# Generate images based on the selected model
def generate_images(text, selected_model):
    stop_event.clear()

    if selected_model == "Model 1 (Turbo Realism)":
        model = model1
    elif selected_model == "Model 2 (Face Projection)":
        model = model2
    else:
        return ["Invalid model selection."] * 3

    results = []
    for i in range(3):
        if stop_event.is_set():
            return ["Image generation stopped by user."] * 3
        modified_text = f"{text} variation {i+1}"
        result = model(modified_text)
        results.append(result)

    return results


# Visual Question Answering (Pixtral-12B)
def answer_question(text, image, temperature=0.0, max_tokens=1024):
    if image is None:
        return "Please upload an image before submitting a question."

    base64_string, file_format = pil_to_base64(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": f"data:image/{file_format};base64,{base64_string}"},
            ],
        }
    ]
    chat_response = mistral_client.chat.complete(
        model=pixtral_model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return chat_response.choices[0].message.content


# Clear all VQA fields
def clear_all():
    return "", None, ""


# Set up the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎓 AI Tutor & Visual Learning Assistant")

    # Section 1: Text-based Learning
    with gr.Row():
        with gr.Column(scale=2):
            subject = gr.Dropdown(["Math", "Science", "History", "Literature", "Code", "AI"], label="Subject")
            difficulty = gr.Radio(["Beginner", "Intermediate", "Advanced"], label="Difficulty Level")
            student_input = gr.Textbox(placeholder="Type your query here...", label="Your Input")
            submit_button_text = gr.Button("Generate Lesson & Question", variant="primary")
        with gr.Column(scale=3):
            lesson_output = gr.Markdown(label="Lesson")
            question_output = gr.Markdown(label="Comprehension Question")
            feedback_output = gr.Markdown(label="Feedback")

    # Section 2: Image Generation
    with gr.Row():
        with gr.Column(scale=2):
            model_selector = gr.Radio(
                ["Model 1 (Turbo Realism)", "Model 2 (Face Projection)"],
                label="Select Image Generation Model",
                value="Model 1 (Turbo Realism)",
            )
            submit_button_visual = gr.Button("Generate Visuals", variant="primary")
        with gr.Column(scale=3):
            output1 = gr.Image(label="Generated Image 1")
            output2 = gr.Image(label="Generated Image 2")
            output3 = gr.Image(label="Generated Image 3")

    # Section 3: Visual Question Answering (Pixtral-12B)
    gr.Markdown("## 🖼️ Visual Question Answering (Pixtral-12B)")
    with gr.Row():
        with gr.Column(scale=2):
            question = gr.Textbox(placeholder="Ask about the image...", lines=2)
            image = gr.Image(type="pil")
            with gr.Row():
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=2048, value=1024, step=128)
        with gr.Column(scale=3):
            output_text = gr.Textbox(lines=10, label="Pixtral 12B Response")
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                submit_btn_vqa = gr.Button("Submit", variant="primary")

    # Generate Text-based Output
    submit_button_text.click(
        fn=generate_tutor_fields,
        inputs=[subject, difficulty, student_input],
        outputs=[lesson_output, question_output, feedback_output],
    )

    # Generate Visual Output
    submit_button_visual.click(
        fn=generate_images,
        inputs=[student_input, model_selector],
        outputs=[output1, output2, output3],
    )

    # VQA Processing
    submit_btn_vqa.click(
        fn=answer_question,
        inputs=[question, image, temperature, max_tokens],
        outputs=[output_text],
    )

    # Clear VQA Inputs
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[question, image, output_text],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)