# NOTE: Removed Hugging Face file-viewer page residue ("ChaseHan's picture /
# Update app.py / 1ff383d verified / raw / history blame / 6.78 kB") that was
# accidentally captured with the source; it is not Python and broke parsing.
import gradio as gr
import cv2
import numpy as np
import os
import requests
import json
from PIL import Image
import io
import base64
from openai import OpenAI
# API endpoints
# Placeholder layout-detection endpoint; detect_layout() POSTs the image here.
YOLO_API_ENDPOINT = "https://api.example.com/yolo" # Replace with actual YOLO API endpoint
# Qwen API configuration
# OpenAI-compatible gateway for DashScope; used as base_url for the OpenAI client.
QWEN_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# Vision-language model queried in qa_about_layout().
QWEN_MODEL_ID = "qwen2.5-vl-3b-instruct"
def encode_image(image_array):
    """
    Serialize a numpy image array to a base64 string of its PNG encoding.

    Args:
        image_array: numpy array holding the image pixels.

    Returns:
        str: base64-encoded text of the PNG bytes.
    """
    # Render the array as PNG into an in-memory buffer, then base64 it.
    buffer = io.BytesIO()
    Image.fromarray(image_array).save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def detect_layout(image):
    """
    Perform layout detection on the uploaded image using the YOLO API.

    Args:
        image: The uploaded image as a numpy array (as delivered by the
            gr.Image "numpy" component), or None if nothing was uploaded.

    Returns:
        tuple: (annotated_image, layout_info)
            annotated_image: copy of the input with detection boxes drawn,
                or None on error.
            layout_info: JSON string of the raw detection results, or an
                error message on failure.
    """
    if image is None:
        return None, "Error: No image uploaded."

    # Serialize the image as PNG bytes for the multipart upload.
    pil_image = Image.fromarray(image)
    img_byte_arr = io.BytesIO()
    pil_image.save(img_byte_arr, format='PNG')
    files = {'image': ('image.png', img_byte_arr.getvalue(), 'image/png')}

    try:
        # Bound the request so an unreachable endpoint cannot hang the
        # Gradio worker forever (original call had no timeout).
        response = requests.post(YOLO_API_ENDPOINT, files=files, timeout=30)
        response.raise_for_status()
        # Expected schema: a list of {"bbox": [x1,y1,x2,y2], "class": str,
        # "confidence": float} — TODO confirm against the real endpoint.
        detection_results = response.json()

        # Draw on a copy so the user's uploaded image stays untouched.
        annotated_image = image.copy()
        for detection in detection_results:
            # Cast once up front; cv2 drawing APIs require integer pixels.
            x1, y1, x2, y2 = (int(v) for v in detection['bbox'])
            cls_name = detection['class']
            conf = detection['confidence']

            # Random color per box. NOTE(review): colors are not stable per
            # class between detections or runs.
            color = tuple(np.random.randint(0, 255, 3).tolist())

            cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
            label = f'{cls_name} {conf:.2f}'
            (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            # Filled rectangle behind the text so the label stays readable.
            cv2.rectangle(annotated_image, (x1, y1 - label_height - 5), (x1 + label_width, y1), color, -1)
            cv2.putText(annotated_image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Pretty-printed JSON doubles as the context fed to Qwen later.
        layout_info = json.dumps(detection_results, indent=2)
        return annotated_image, layout_info
    except Exception as e:
        # Surface the failure in the UI textbox rather than crashing the app.
        return None, f"Error during layout detection: {str(e)}"
def qa_about_layout(image, question, layout_info, api_key):
    """
    Answer questions about the layout using the Qwen2.5-VL API.

    Args:
        image: The uploaded image as a numpy array, or None.
        question: User's question about the layout.
        layout_info: Layout detection results (JSON string) from detect_layout.
        api_key: User's Qwen (DashScope) API key.

    Returns:
        str: Qwen's answer, or a human-readable error message.
    """
    # Guard clauses: every input is user-supplied through the UI.
    if image is None or not question:
        return "Please upload an image and ask a question."
    if not layout_info:
        return "No layout information available. Please detect layout first."
    if not api_key:
        return "Please enter your Qwen API key."

    try:
        # Encode image to base64 (encode_image produces PNG bytes).
        base64_image = encode_image(image)

        # DashScope exposes an OpenAI-compatible endpoint, so the stock
        # OpenAI client works with a swapped base_url.
        client = OpenAI(
            api_key=api_key,
            base_url=QWEN_BASE_URL,
        )

        # Ground the model with the detected layout before the user question.
        system_prompt = f"""You are a helpful assistant specialized in analyzing document layouts.
The following layout information has been detected in the image:
{layout_info}
Please answer questions about the layout based on this information and the image."""

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}]
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        # BUG FIX: encode_image emits PNG, but the data URL
                        # previously declared image/jpeg — MIME now matches.
                        "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                    },
                    {"type": "text", "text": question},
                ],
            }
        ]

        completion = client.chat.completions.create(
            model=QWEN_MODEL_ID,
            messages=messages,
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Report API/auth failures in the answer box instead of raising.
        return f"Error during QA: {str(e)}"
# Create Gradio interface
# Layout: top row = image upload + detection output; bottom row = QA inputs
# and the model's answer. Component creation order defines on-screen order.
with gr.Blocks(title="Latex2Layout QA System") as demo:
    gr.Markdown("# Latex2Layout QA System")
    gr.Markdown("Upload an image, detect layout elements, and ask questions about the layout.")
    with gr.Row():
        with gr.Column(scale=1):
            # "numpy" type hands detect_layout a raw array, as it expects.
            input_image = gr.Image(label="Upload Image", type="numpy")
            detect_btn = gr.Button("Detect Layout")
            gr.Markdown("**Tip**: Upload a clear image for optimal detection results.")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Detection Results")
            # Also serves as the layout context passed into qa_about_layout.
            layout_info = gr.Textbox(label="Layout Information", lines=10)
    with gr.Row():
        with gr.Column(scale=1):
            # Masked input: the key is user-provided, never stored server-side here.
            api_key_input = gr.Textbox(
                label="Qwen API Key",
                placeholder="Enter your Qwen API key here",
                type="password"
            )
            question_input = gr.Textbox(label="Ask a question about the layout")
            qa_btn = gr.Button("Ask Question")
        with gr.Column(scale=1):
            answer_output = gr.Textbox(label="Answer", lines=5)
    # Event handlers
    # Detection fills both the annotated image and the layout JSON textbox.
    detect_btn.click(
        fn=detect_layout,
        inputs=[input_image],
        outputs=[output_image, layout_info]
    )
    # QA reads the image, question, detected layout, and API key together.
    qa_btn.click(
        fn=qa_about_layout,
        inputs=[input_image, question_input, layout_info, api_key_input],
        outputs=[answer_output]
    )
# Launch the application
if __name__ == "__main__":
    demo.launch()