# whatisit / model.py
# Author: root-sajjan
# Last commit: e2e8ffc (verified) - "edited tesseract error handling"
import torch
from pathlib import Path
from transformers import CLIPProcessor, CLIPModel
from PIL import Image, ImageDraw
import pytesseract
import requests
import os
from llm import inference, upload_image
from fastapi.responses import FileResponse, JSONResponse
import re
from io import BytesIO
# Every crop produced by YOLOModel.predict is written under this directory
# (relative to the process working directory); it is created once, at import.
cropped_images_dir = "cropped_images"
Path(cropped_images_dir).mkdir(parents=True, exist_ok=True)
# YOLOv5 detector whose detections are enriched with LLM, OCR and (optionally) CLIP output.
class YOLOModel:
    """
    Wrap a torch.hub YOLOv5 model and post-process each detection.

    ``predict`` crops every detection, uploads the crop, asks the LLM helper
    for brand/model/price/description, and runs Tesseract OCR on the crop.
    ``predict_clip`` offers CLIP zero-shot label ranking for a single image.
    """

    # File names written by ``predict`` for each crop: "crop_<index>.jpg".
    _CROP_FILE_RE = re.compile(r"crop_(\d+)\.jpg")

    # Hugging Face checkpoint used by ``predict_clip`` unless overridden.
    _DEFAULT_CLIP_MODEL = "openai/clip-vit-base-patch32"

    # CLIP is loaded lazily on the first ``predict_clip`` call so that plain
    # YOLO inference never pays for the CLIP download. Class-level defaults
    # keep instances usable even if ``__init__`` was bypassed.
    clip_model_name = _DEFAULT_CLIP_MODEL
    clip_model = None
    clip_processor = None

    def __init__(self, model_path="yolov5s.pt", *, force_reload=True,
                 clip_model_name=_DEFAULT_CLIP_MODEL):
        """
        Load the YOLOv5 model via ``torch.hub``.

        :param model_path: Weights file handed to the ``ultralytics/yolov5``
            ``custom`` hub entry point.
        :param force_reload: Forwarded to ``torch.hub.load``. The default True
            re-fetches the hub repo on every construction; pass False to reuse
            the locally cached copy.
        :param clip_model_name: Checkpoint loaded by ``predict_clip`` on first use.
        """
        # Workaround that disables torch.hub's forked-repo validation. This
        # patches a private torch API process-wide, not just for this instance.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load(
            "ultralytics/yolov5", "custom", path=model_path, force_reload=force_reload
        )
        self.clip_model_name = clip_model_name

    def _ensure_clip(self):
        """Load the CLIP processor and model on first use (no-op once loaded)."""
        if self.clip_model is None or self.clip_processor is None:
            self.clip_processor = CLIPProcessor.from_pretrained(self.clip_model_name)
            self.clip_model = CLIPModel.from_pretrained(self.clip_model_name)

    def predict_clip(self, image, brand_names):
        """
        Predict the most probable brand for a single image using CLIP.

        :param image: One image in a form accepted by ``CLIPProcessor``.
        :param brand_names: Non-empty sequence of candidate label strings.
        :return: ``(best_label, probability)``, where the probability is the
            softmax over ``brand_names`` for the first (only) image.
        :raises ValueError: If ``brand_names`` is empty.
        """
        if not brand_names:
            raise ValueError("brand_names must not be empty")
        self._ensure_clip()
        inputs = self.clip_processor(
            text=brand_names,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        # One row per image; softmax across the candidate labels.
        probs = outputs.logits_per_image.softmax(dim=1)
        best_idx = probs[0].argmax().item()
        return brand_names[best_idx], probs[0, best_idx].item()

    def predict_text(self, image):
        """
        Run Tesseract OCR on a grayscale copy of ``image``.

        OCR is best-effort: any failure (e.g. Tesseract not installed) is
        printed and an empty string is returned instead of raising.

        :param image: PIL image to read text from.
        :return: The recognised text with surrounding whitespace stripped.
        """
        try:
            grayscale = image.convert('L')
            text = pytesseract.image_to_string(grayscale)
            return text.strip()
        except Exception as e:
            print(f"Error during text prediction: {e}")
            return ""

    def predict(self, image_path):
        """
        Run YOLO inference on an image and enrich each detection.

        Each detection is cropped, saved as ``crop_<idx>.jpg`` in
        ``cropped_images_dir``, uploaded with
        ``upload_image.upload_image_to_imgbb``, sent with its YOLO class name to
        ``inference.get_name``, and OCR'd. Afterwards, crop files in that
        directory whose index was not produced by this call are deleted.

        NOTE(review): crop file names are fixed per index, so concurrent calls
        would overwrite each other's crops.

        :param image_path: Path to the input image.
        :return: List with one dict per detection, with keys ``category``,
            ``bbox`` (``[xmin, ymin, xmax, ymax]``), ``confidence``,
            ``category_llm``, ``predicted_brand``, ``price``, ``details``,
            ``detected_text``, ``image_path`` and ``image_url``.
        """
        results = self.model(image_path)
        predictions = results.pandas().xyxy[0]  # one DataFrame row per detection
        print(f'YOLO predictions:\n\n{predictions}')

        with Image.open(image_path) as source:
            image = source.convert("RGB")
        # Annotate a copy so boxes/labels drawn for earlier detections never
        # leak into later crops, which are what gets uploaded and OCR'd.
        # NOTE(review): the annotated image is currently neither saved nor returned.
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)

        output = [
            self._process_detection(image, draw, idx, row)
            for idx, row in predictions.iterrows()
        ]

        self._remove_stale_crops(cropped_images_dir, set(predictions.index))
        return output

    def _process_detection(self, image, draw, idx, row):
        """Crop, upload, LLM-query, OCR and annotate one detection; return its result dict."""
        category = row['name']
        confidence = row['confidence']
        bbox = [row["xmin"], row["ymin"], row["xmax"], row["ymax"]]

        cropped_image = image.crop(tuple(bbox))
        cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg")
        cropped_image.save(cropped_image_path, "JPEG")

        # The LLM helper takes an image URL, so the crop is uploaded first.
        print('Uploading now to image url')
        image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
        print(f'Image URL received as{image_url}')
        result_llms = inference.get_name(image_url, category)

        detected_text = self.predict_text(cropped_image)
        print(f'Details:{detected_text}')
        print(f'Predicted brand: {result_llms["model"]}')

        draw.rectangle(bbox, outline="red", width=3)
        draw.text(
            (bbox[0], bbox[1] - 10),
            f'{result_llms["brand"]}',
            fill="red",
        )

        return {
            "category": category,
            "bbox": bbox,
            "confidence": confidence,
            "category_llm": result_llms["brand"],
            "predicted_brand": result_llms["model"],
            "price": result_llms["price"],
            "details": result_llms["description"],
            "detected_text": detected_text,
            "image_path": cropped_image_path,
            "image_url": image_url,
        }

    @classmethod
    def _remove_stale_crops(cls, directory, keep_indices):
        """Delete ``crop_<n>.jpg`` files in ``directory`` whose ``n`` is not in ``keep_indices``."""
        for filename in os.listdir(directory):
            match = cls._CROP_FILE_RE.fullmatch(filename)
            # Leave unrelated files and crops from the current call untouched.
            if match is None or int(match.group(1)) in keep_indices:
                continue
            os.remove(os.path.join(directory, filename))
            print(f"Deleted excess file: {filename}")