import os
import json
import gradio as gr
import google.generativeai as genai
from PIL import Image
import numpy as np
from dotenv import load_dotenv
import traceback
import pytesseract
import cv2
# Load environment variables
load_dotenv()
# Set API key for Gemini
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("Gemini API key not found. Please set the GEMINI_API_KEY environment variable.")
genai.configure(api_key=GEMINI_API_KEY)
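# Note: on a Hugging Face Space the key can be added as a repository secret named
# GEMINI_API_KEY; Space secrets are exposed to the app as environment variables,
# so os.getenv() above picks it up without any extra code.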
# Define model names - using latest models
CLASSIFICATION_MODEL = "gemini-1.5-flash"  # For classification
SOLUTION_MODEL = "gemini-1.5-pro-latest"  # For solution generation
EXPLANATION_MODEL = "gemini-1.5-pro-latest"  # For explanation generation
SIMILAR_MODEL = "gemini-1.5-pro-latest"  # For similar problems generation
print(f"Using models: Classification: {CLASSIFICATION_MODEL}, Solution: {SOLUTION_MODEL}, Explanation: {EXPLANATION_MODEL}, Similar: {SIMILAR_MODEL}")
# Set up Gemini for image analysis
MODEL_IMAGE = "gemini-1.5-pro-latest"  # Use Gemini for OCR as well
# Set Tesseract path - Mac with Homebrew default
pytesseract.pytesseract.tesseract_cmd = '/opt/homebrew/bin/tesseract'
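# The hard-coded Homebrew path above assumes a local macOS setup. On a Linux host
# (for example a Hugging Face Space with tesseract-ocr installed via packages.txt),
# the binary usually lives on PATH instead, so fall back to whatever `which` finds.
import shutil
if not os.path.exists(pytesseract.pytesseract.tesseract_cmd) and shutil.which("tesseract"):
    pytesseract.pytesseract.tesseract_cmd = shutil.which("tesseract")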
# Extract text using Gemini directly (with Tesseract as fallback)
def extract_text_with_gemini(image):
    """Extract text from image using Gemini Pro Vision directly"""
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        model = genai.GenerativeModel(MODEL_IMAGE)
        prompt = """
        Extract ALL text, numbers, and mathematical equations from this image precisely.
        Include ALL symbols, numbers, letters, and mathematical notation exactly as they appear.
        Format any equations properly and maintain their layout.
        Don't explain the content, just extract the text verbatim.
        """
        response = model.generate_content([prompt, image])
        extracted_text = response.text.strip()
        # If Gemini returns a very short result, try Tesseract as fallback
        if len(extracted_text) < 10:
            print("Gemini returned limited text, trying Tesseract as fallback")
            if isinstance(image, Image.Image):
                image_array = np.array(image)
            else:
                image_array = image
            if len(image_array.shape) == 3:
                gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
            else:
                gray = image_array
            custom_config = r'--oem 1 --psm 6'
            tesseract_text = pytesseract.image_to_string(gray, config=custom_config)
            if len(tesseract_text) > len(extracted_text):
                extracted_text = tesseract_text
        print(f"Extracted text: {extracted_text[:100]}...")
        return extracted_text
    except Exception as e:
        print(f"Extraction Error: {e}")
        print(traceback.format_exc())
        try:
            if isinstance(image, Image.Image):
                image_array = np.array(image)
            else:
                image_array = image
            if len(image_array.shape) == 3:
                gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
            else:
                gray = image_array
            return pytesseract.image_to_string(gray, config=r'--oem 1 --psm 6')
        except Exception as e2:
            print(f"Fallback OCR Error: {e2}")
            return f"Error extracting text: {str(e)}"
# Classify the math problem using Gemini 1.5 Flash
def classify_with_gemini_flash(math_problem):
    """Classify the math problem using Gemini model"""
    try:
        model = genai.GenerativeModel(
            model_name=CLASSIFICATION_MODEL,
            generation_config={
                "temperature": 0.1,
                "top_p": 0.95,
                "max_output_tokens": 150,
                "response_mime_type": "application/json",
            }
        )
        prompt = f"""
        Task: Classify the following math problem.
        PROBLEM: {math_problem}
        Classify this math problem according to:
        1. Primary category (e.g., Algebra, Calculus, Geometry, Trigonometry, Statistics, Number Theory)
        2. Specific subtopic (e.g., Linear Equations, Derivatives, Integrals, Probability)
        3. Difficulty level (Basic, Intermediate, Advanced)
        4. Key concepts involved
        Format the response as a JSON object with the fields: "category", "subtopic", "difficulty", "key_concepts".
        """
        response = model.generate_content(prompt)
        try:
            classification = json.loads(response.text)
            return classification
        except json.JSONDecodeError:
            print(f"JSON Decode Error: Unable to parse response: {response.text}")
            return {
                "category": "Unknown",
                "subtopic": "Unknown",
                "difficulty": "Unknown",
                "key_concepts": ["Unknown"]
            }
    except Exception as e:
        print(f"Classification Error: {e}")
        print(traceback.format_exc())
        return {
            "category": "Error",
            "subtopic": "Error",
            "difficulty": "Error",
            "key_concepts": [f"Error: {str(e)}"]
        }
# Solve the math problem using Gemini model
def solve_with_gemini_pro(math_problem, classification):
    """Solve the math problem using Gemini model"""
    try:
        model = genai.GenerativeModel(
            model_name=SOLUTION_MODEL,
            generation_config={
                "temperature": 0.2,
                "top_p": 0.9,
                "max_output_tokens": 1000,
            }
        )
        # Ensure classification has the required fields with fallbacks
        if not isinstance(classification, dict):
            classification = {
                "category": "Unknown",
                "subtopic": "Unknown",
                "difficulty": "Unknown",
                "key_concepts": ["Unknown"]
            }
        for field in ["category", "subtopic", "difficulty"]:
            if field not in classification or not classification[field]:
                classification[field] = "Unknown"
        if "key_concepts" not in classification or not classification["key_concepts"]:
            classification["key_concepts"] = ["Unknown"]
        # Format key concepts as a string
        if isinstance(classification["key_concepts"], list):
            key_concepts = ", ".join(classification["key_concepts"])
        else:
            key_concepts = str(classification["key_concepts"])
        prompt = f"""
        Task: Solve the following math problem with clear step-by-step explanations.
        PROBLEM: {math_problem}
        CLASSIFICATION:
        - Category: {classification["category"]}
        - Subtopic: {classification["subtopic"]}
        - Difficulty: {classification["difficulty"]}
        - Key Concepts: {key_concepts}
        Provide a complete solution following these guidelines:
        1. Start with an overview of the approach
        2. Break down the problem into clear, logical steps
        3. Explain each step thoroughly, mentioning the mathematical principles applied
        4. Show all work and calculations
        5. Verify the answer if possible
        6. Summarize the key takeaway from this problem
        Format the solution to be readable on a mobile device, with appropriate spacing between steps.
        """
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Solution Error: {e}")
        print(traceback.format_exc())
        return f"Error generating solution: {str(e)}"
# Explain the solution in more detail
def explain_solution(math_problem, solution):
    """Provide a more detailed explanation of the solution"""
    try:
        print("Generating detailed explanation...")
        model = genai.GenerativeModel(
            model_name=EXPLANATION_MODEL,
            generation_config={
                "temperature": 0.3,
                "top_p": 0.95,
                "max_output_tokens": 1500,
            }
        )
        prompt = f"""
        Task: Provide a more detailed explanation of the solution to this math problem.
        PROBLEM: {math_problem}
        SOLUTION: {solution}
        Provide a more comprehensive explanation that:
        1. Breaks down complex steps into simpler components
        2. Explains the underlying mathematical principles in depth
        3. Connects this problem to fundamental concepts
        4. Offers visual or intuitive ways to understand the concepts
        5. Highlights common mistakes students make with this type of problem
        6. Suggests alternative solution approaches if applicable
        Make the explanation accessible to a student who is struggling with this topic.
        """
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Explanation Error: {e}")
        print(traceback.format_exc())
        return f"Error generating explanation: {str(e)}"
# Generate similar practice problems
def generate_similar_problems(math_problem, classification):
    """Generate similar practice math problems"""
    try:
        print("Generating similar problems...")
        model = genai.GenerativeModel(
            model_name=SIMILAR_MODEL,
            generation_config={
                "temperature": 0.7,
                "top_p": 0.95,
                "max_output_tokens": 1000,
            }
        )
        # Prepare classification string
        classification_str = json.dumps(classification, indent=2)
        prompt = f"""
        Task: Generate similar practice math problems based on the following problem.
        ORIGINAL PROBLEM: {math_problem}
        CLASSIFICATION: {classification_str}
        Generate 3 similar practice problems that:
        1. Cover the same mathematical concepts and principles
        2. Vary in difficulty (one easier, one similar, one harder)
        3. Use different numerical values or variables
        4. Test the same underlying skills
        For each problem:
        - Provide the complete problem statement
        - Include a brief hint for solving it
        - Provide the correct answer (but not the full solution)
        Format as three separate problems with clear numbering.
        """
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Similar Problems Error: {e}")
        print(traceback.format_exc())
        return f"Error generating similar problems: {str(e)}"
# Main function for processing images
def process_image(image, progress=gr.Progress()):
    """Main processing pipeline for the NerdAI app"""
    try:
        if image is None:
            return None, "No image uploaded", "No image uploaded", "No image uploaded", "No image uploaded"
        progress(0, desc="Starting processing...")
        # Step 1: Extract text with Gemini model
        progress(0.4, desc="Extracting text with Gemini Pro Vision...")
        extracted_text = extract_text_with_gemini(image)
        if not extracted_text or extracted_text.strip() == "":
            return image, "No text was extracted from the image. Please try a clearer image.", "No text extracted", "No text was extracted from the image.", ""
        # Step 2: Classify with Gemini model
        progress(0.6, desc=f"Classifying problem with {CLASSIFICATION_MODEL}...")
        classification = classify_with_gemini_flash(extracted_text)
        classification_json = json.dumps(classification, indent=2)
        # Step 3: Solve with Gemini model
        progress(0.8, desc=f"Solving problem with {SOLUTION_MODEL}...")
        solution = solve_with_gemini_pro(extracted_text, classification)
        # Complete
        progress(1.0, desc="Processing complete")
        return image, extracted_text, classification_json, solution, extracted_text
    except Exception as e:
        print(f"Process Image Error: {e}")
        print(traceback.format_exc())
        return None, f"Error processing image: {str(e)}", "Error", "Error", ""
# Create the Gradio interface
with gr.Blocks(title="NerdAI Math Problem Solver") as demo:
    gr.Markdown("# NerdAI Math Problem Solver")
    gr.Markdown("Upload an image of a math problem to get a step-by-step solution")
    # Store state variables
    extracted_text_state = gr.State("")
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(label="Upload Math Problem Image", type="pil")
            process_btn = gr.Button("Process Image", variant="primary")
        with gr.Column(scale=1):
            # Processed image output
            processed_image = gr.Image(label="Processed Image")
    with gr.Row():
        # Text extraction output
        extracted_text = gr.Textbox(label="Extracted Text", lines=3)
    with gr.Row():
        # Classification output
        classification = gr.Textbox(label="Problem Classification", lines=6)
    with gr.Row():
        # Solution output
        solution = gr.Markdown(label="Solution")
    with gr.Row():
        explain_btn = gr.Button("Explain It", variant="secondary")
        similar_btn = gr.Button("Similar Questions", variant="secondary")
    with gr.Row():
        # Additional outputs
        with gr.Tabs():
            with gr.TabItem("Detailed Explanation"):
                explanation = gr.Markdown()
            with gr.TabItem("Similar Practice Problems"):
                similar_problems = gr.Markdown()
    # Event handlers for the buttons
    def explain_button_handler(math_problem, solution_text):
        """Handler for Explain It button"""
        print("Explain button clicked")
        if not math_problem or math_problem == "No image uploaded":
            return "Please process an image first"
        return explain_solution(math_problem, solution_text)

    def similar_button_handler(math_problem, classification_json):
        """Handler for Similar Questions button"""
        print("Similar button clicked")
        if not math_problem or math_problem == "No image uploaded":
            return "Please process an image first"
        try:
            # Parse classification JSON
            try:
                classification = json.loads(classification_json)
            except (json.JSONDecodeError, TypeError):
                classification = {
                    "category": "Unknown",
                    "subtopic": "Unknown",
                    "difficulty": "Unknown",
                    "key_concepts": ["Unknown"]
                }
            # Validate classification
            if not isinstance(classification, dict):
                classification = {
                    "category": "Unknown",
                    "subtopic": "Unknown",
                    "difficulty": "Unknown",
                    "key_concepts": ["Unknown"]
                }
            # Ensure fields exist
            for field in ["category", "subtopic", "difficulty"]:
                if field not in classification or not classification[field]:
                    classification[field] = "Unknown"
            if "key_concepts" not in classification or not classification["key_concepts"]:
                classification["key_concepts"] = ["Unknown"]
            return generate_similar_problems(math_problem, classification)
        except Exception as e:
            print(f"Error in similar_button_handler: {e}")
            print(traceback.format_exc())
            return f"Error generating similar problems: {str(e)}"
    # Set up event handlers
    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[processed_image, extracted_text, classification, solution, extracted_text_state]
    )
    explain_btn.click(
        fn=explain_button_handler,
        inputs=[extracted_text_state, solution],
        outputs=explanation
    )
    similar_btn.click(
        fn=similar_button_handler,
        inputs=[extracted_text_state, classification],
        outputs=similar_problems
    )
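    # extracted_text_state holds the raw problem text returned by process_image,
    # so the Explain It and Similar Questions handlers can reuse it without
    # re-running OCR on the uploaded image.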
# Launch the app
if __name__ == "__main__":
    demo.launch()
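    # For local debugging, Gradio's standard launch options can be passed here,
    # e.g. demo.launch(share=True) for a temporary public link, or
    # demo.launch(server_name="0.0.0.0", server_port=7860) to bind explicitly.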