SameerArz committed on
Commit
0576dea
·
verified ·
1 Parent(s): 3383ecb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -50
app.py CHANGED
@@ -1,11 +1,18 @@
1
  import gradio as gr
2
  from groq import Groq
3
  import os
4
- import threading # Import threading module
 
 
 
5
 
6
- # Initialize Groq client with your API key
7
  client = Groq(api_key=os.environ["GROQ_API_KEY"])
8
 
 
 
 
 
9
  # Load Text-to-Image Models
10
  model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
11
  model2 = gr.load("models/Purz/face-projection")
@@ -13,6 +20,15 @@ model2 = gr.load("models/Purz/face-projection")
13
  # Stop event for threading (image generation)
14
  stop_event = threading.Event()
15
 
 
 
 
 
 
 
 
 
 
16
  # Function to generate tutor output (lesson, question, feedback)
17
  def generate_tutor_output(subject, difficulty, student_input):
18
  prompt = f"""
@@ -30,26 +46,24 @@ def generate_tutor_output(subject, difficulty, student_input):
30
  completion = client.chat.completions.create(
31
  messages=[{
32
  "role": "system",
33
- "content": f"You are the world's best AI tutor, renowned for your ability to explain complex concepts in an engaging, clear, and memorable way and giving math examples. Your expertise in {subject} is unparalleled, and you're adept at tailoring your teaching to {difficulty} level students."
34
  }, {
35
  "role": "user",
36
  "content": prompt,
37
  }],
38
- model="mixtral-8x7b-32768", # Model for text generation
39
  max_tokens=1000,
40
  )
41
 
42
  return completion.choices[0].message.content
43
 
 
44
  # Function to generate images based on model selection
45
  def generate_images(text, selected_model):
46
  stop_event.clear()
47
 
48
- if selected_model == "Model 1 (Turbo Realism)":
49
- model = model1
50
- elif selected_model == "Model 2 (Face Projection)":
51
- model = model2
52
- else:
53
  return ["Invalid model selection."] * 3
54
 
55
  results = []
@@ -63,28 +77,46 @@ def generate_images(text, selected_model):
63
 
64
  return results
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Set up the Gradio interface
67
  with gr.Blocks() as demo:
68
- gr.Markdown("# 🎓 Your AI Tutor with Visuals & Images")
69
 
70
- # Section for generating Text-based output (lesson, question, feedback)
71
  with gr.Row():
72
  with gr.Column(scale=2):
73
- subject = gr.Dropdown(
74
- ["Math", "Science", "History", "Literature", "Code", "AI"],
75
- label="Subject",
76
- info="Choose the subject of your lesson"
77
- )
78
- difficulty = gr.Radio(
79
- ["Beginner", "Intermediate", "Advanced"],
80
- label="Difficulty Level",
81
- info="Select your proficiency level"
82
- )
83
- student_input = gr.Textbox(
84
- placeholder="Type your query here...",
85
- label="Your Input",
86
- info="Enter the topic you want to learn"
87
- )
88
  submit_button_text = gr.Button("Generate Lesson & Question", variant="primary")
89
 
90
  with gr.Column(scale=3):
@@ -92,7 +124,7 @@ with gr.Blocks() as demo:
92
  question_output = gr.Markdown(label="Comprehension Question")
93
  feedback_output = gr.Markdown(label="Feedback")
94
 
95
- # Section for generating Visual output
96
  with gr.Row():
97
  with gr.Column(scale=2):
98
  model_selector = gr.Radio(
@@ -107,39 +139,50 @@ with gr.Blocks() as demo:
107
  output2 = gr.Image(label="Generated Image 2")
108
  output3 = gr.Image(label="Generated Image 3")
109
 
110
- gr.Markdown("""
111
- ### How to Use
112
- 1. **Text Section**: Select a subject and difficulty, type your query, and click 'Generate Lesson & Question' to get your personalized lesson, comprehension question, and feedback.
113
- 2. **Visual Section**: Select the model for image generation, then click 'Generate Visuals' to receive 3 variations of an image based on your topic.
114
- 3. Review the AI-generated content to enhance your learning experience!
115
- """)
116
-
117
- def process_output_text(subject, difficulty, student_input):
118
- try:
119
- tutor_output = generate_tutor_output(subject, difficulty, student_input)
120
- parsed = eval(tutor_output) # Convert string to dictionary
121
- return parsed["lesson"], parsed["question"], parsed["feedback"]
122
- except:
123
- return "Error parsing output", "No question available", "No feedback available"
124
-
125
- def process_output_visual(text, selected_model):
126
- try:
127
- images = generate_images(text, selected_model)
128
- return images[0], images[1], images[2]
129
- except:
130
- return None, None, None
131
 
 
 
 
 
 
132
  submit_button_text.click(
133
- fn=process_output_text,
134
  inputs=[subject, difficulty, student_input],
135
  outputs=[lesson_output, question_output, feedback_output]
136
  )
137
 
 
138
  submit_button_visual.click(
139
- fn=process_output_visual,
140
  inputs=[student_input, model_selector],
141
  outputs=[output1, output2, output3]
142
  )
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  if __name__ == "__main__":
145
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
  from groq import Groq
3
  import os
4
+ import threading
5
+ import base64
6
+ from io import BytesIO
7
+ from mistralai import Mistral # Pixtral-12B integration
8
 
9
+ # Initialize Groq client
10
  client = Groq(api_key=os.environ["GROQ_API_KEY"])
11
 
12
+ # Initialize Mistral AI client (Pixtral-12B-2409 for VQA)
13
+ mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
14
+ pixtral_model = "pixtral-12b-2409"
15
+
16
  # Load Text-to-Image Models
17
  model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
18
  model2 = gr.load("models/Purz/face-projection")
 
20
  # Stop event for threading (image generation)
21
  stop_event = threading.Event()
22
 
23
+
24
+ # Convert PIL image to Base64
25
def pil_to_base64(pil_image, image_format='jpeg'):
    """Encode an image as a Base64 string.

    Args:
        pil_image: Image object exposing the PIL-style ``mode`` attribute and
            ``convert`` / ``save(fp, format=...)`` methods.
        image_format: Format name passed to ``save``; defaults to ``'jpeg'``.

    Returns:
        A ``(base64_string, image_format)`` tuple, where ``base64_string`` is
        the UTF-8 decoded Base64 encoding of the saved image bytes.
    """
    # JPEG cannot store alpha or palette data; saving e.g. an RGBA image (a
    # typical PNG upload) as JPEG fails, so flatten such images to RGB first.
    if image_format.lower() in ("jpeg", "jpg") and pil_image.mode not in ("RGB", "L", "CMYK"):
        pil_image = pil_image.convert("RGB")

    buffered = BytesIO()
    pil_image.save(buffered, format=image_format)
    base64_string = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return base64_string, image_format
30
+
31
+
32
  # Function to generate tutor output (lesson, question, feedback)
33
  def generate_tutor_output(subject, difficulty, student_input):
34
  prompt = f"""
 
46
  completion = client.chat.completions.create(
47
  messages=[{
48
  "role": "system",
49
+ "content": f"You are the world's best AI tutor, renowned for explaining complex concepts with clarity and examples. Your expertise in {subject} is unparalleled, and you're adept at tailoring your teaching to {difficulty} level students."
50
  }, {
51
  "role": "user",
52
  "content": prompt,
53
  }],
54
+ model="mixtral-8x7b-32768",
55
  max_tokens=1000,
56
  )
57
 
58
  return completion.choices[0].message.content
59
 
60
+
61
  # Function to generate images based on model selection
62
  def generate_images(text, selected_model):
63
  stop_event.clear()
64
 
65
+ model = model1 if selected_model == "Model 1 (Turbo Realism)" else model2 if selected_model == "Model 2 (Face Projection)" else None
66
+ if not model:
 
 
 
67
  return ["Invalid model selection."] * 3
68
 
69
  results = []
 
77
 
78
  return results
79
 
80
+
81
+ # Function for Visual Question Answering (Pixtral-12B)
82
def answer_question(text, image, temperature=0.0, max_tokens=1024):
    """Ask the Pixtral model a question about an image and return its reply.

    Args:
        text: The question sent as the text part of the user message.
        image: Image to ask about; it is encoded via ``pil_to_base64``.
            ``None`` (no image supplied) is handled without calling the API.
        temperature: Sampling temperature forwarded to the chat call.
        max_tokens: Upper bound on generated tokens forwarded to the chat call.

    Returns:
        The content of the first choice of the chat response, or a short
        instruction string when no image was provided.
    """
    # Without this guard a missing image reaches pil_to_base64 as None and
    # fails with AttributeError instead of giving the user a usable message.
    if image is None:
        return "Please upload an image before asking a question."

    base64_string, file_format = pil_to_base64(image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": f"data:image/{file_format};base64,{base64_string}"}
            ]
        }
    ]

    chat_response = mistral_client.chat.complete(
        model=pixtral_model,
        messages=messages,
        temperature=temperature,
        # UI sliders may hand back a float; a token limit must be an integer.
        max_tokens=int(max_tokens)
    )

    return chat_response.choices[0].message.content
103
+
104
+
105
+ # Clear all fields
106
def clear_all():
    """Return reset values for the three fields wired to the Clear button.

    Returns:
        A 3-tuple ``("", None, "")``: empty text, no image, empty text.
    """
    empty_text = ""
    no_image = None
    return (empty_text, no_image, empty_text)
108
+
109
+
110
  # Set up the Gradio interface
111
  with gr.Blocks() as demo:
112
+ gr.Markdown("# 🎓 AI Tutor & Visual Learning Assistant")
113
 
114
+ # Section 1: Text-based Learning
115
  with gr.Row():
116
  with gr.Column(scale=2):
117
+ subject = gr.Dropdown(["Math", "Science", "History", "Literature", "Code", "AI"], label="Subject")
118
+ difficulty = gr.Radio(["Beginner", "Intermediate", "Advanced"], label="Difficulty Level")
119
+ student_input = gr.Textbox(placeholder="Type your query here...", label="Your Input")
 
 
 
 
 
 
 
 
 
 
 
 
120
  submit_button_text = gr.Button("Generate Lesson & Question", variant="primary")
121
 
122
  with gr.Column(scale=3):
 
124
  question_output = gr.Markdown(label="Comprehension Question")
125
  feedback_output = gr.Markdown(label="Feedback")
126
 
127
+ # Section 2: Image Generation
128
  with gr.Row():
129
  with gr.Column(scale=2):
130
  model_selector = gr.Radio(
 
139
  output2 = gr.Image(label="Generated Image 2")
140
  output3 = gr.Image(label="Generated Image 3")
141
 
142
+ # Section 3: Visual Question Answering (Pixtral-12B)
143
+ gr.Markdown("## 🖼️ Visual Question Answering (Pixtral-12B)")
144
+ with gr.Row():
145
+ with gr.Column(scale=2):
146
+ question = gr.Textbox(placeholder="Ask about the image...", lines=2)
147
+ image = gr.Image(type="pil")
148
+ with gr.Row():
149
+ temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
150
+ max_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=2048, value=1024, step=128)
151
+
152
+ with gr.Column(scale=3):
153
+ output_text = gr.Textbox(lines=10, label="Pixtral 12B Response")
 
 
 
 
 
 
 
 
 
154
 
155
+ with gr.Row():
156
+ clear_btn = gr.Button("Clear", variant="secondary")
157
+ submit_btn_vqa = gr.Button("Submit", variant="primary")
158
+
159
+ # Generate Text-based Output
160
  submit_button_text.click(
161
+ fn=lambda subject, difficulty, student_input: eval(generate_tutor_output(subject, difficulty, student_input)),
162
  inputs=[subject, difficulty, student_input],
163
  outputs=[lesson_output, question_output, feedback_output]
164
  )
165
 
166
+ # Generate Visual Output
167
  submit_button_visual.click(
168
+ fn=generate_images,
169
  inputs=[student_input, model_selector],
170
  outputs=[output1, output2, output3]
171
  )
172
 
173
+ # VQA Processing
174
+ submit_btn_vqa.click(
175
+ fn=answer_question,
176
+ inputs=[question, image, temperature, max_tokens],
177
+ outputs=[output_text]
178
+ )
179
+
180
+ # Clear VQA Inputs
181
+ clear_btn.click(
182
+ fn=clear_all,
183
+ inputs=[],
184
+ outputs=[question, image, output_text]
185
+ )
186
+
187
  if __name__ == "__main__":
188
  demo.launch(server_name="0.0.0.0", server_port=7860)