# Spaces:
# Sleeping
# Sleeping
import os
import re
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import pytesseract
import requests
import torch
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image, ImageDraw
from transformers import CLIPModel, CLIPProcessor

from llm import inference, upload_image
# Directory that YOLOModel.predict() writes per-detection crops into
# (relative to the current working directory). Created at import time so it
# always exists before the first prediction.
cropped_images_dir = "cropped_images"
os.makedirs(cropped_images_dir, exist_ok=True)
# Load YOLO model | |
class YOLOModel:
    """YOLOv5 detector that enriches each detection with LLM and OCR output.

    For every box YOLO finds, :meth:`predict` crops the region, saves it under
    ``cropped_images_dir``, uploads it with ``upload_image.upload_image_to_imgbb``
    to obtain a URL, asks ``inference.get_name`` for brand/model/price/description
    fields, and runs Tesseract OCR on the crop.
    """

    # Checkpoint used by predict_clip when no CLIP model has been attached yet.
    # NOTE(review): assumed default -- the original code never loaded CLIP; confirm.
    DEFAULT_CLIP_MODEL = "openai/clip-vit-base-patch32"

    def __init__(
        self,
        model_path: str = "yolov5s.pt",
        force_reload: bool = True,
        clip_model_name: Optional[str] = None,
    ) -> None:
        """Load YOLOv5 weights through ``torch.hub``.

        :param model_path: Weights file passed to
            ``torch.hub.load("ultralytics/yolov5", "custom", path=...)``.
        :param force_reload: Forwarded to ``torch.hub.load``. The default ``True``
            matches the previous hard-coded behaviour (re-fetch the hub repo).
        :param clip_model_name: Checkpoint for :meth:`predict_clip`; defaults to
            :attr:`DEFAULT_CLIP_MODEL`. CLIP is only loaded on first use.
        """
        # Replace torch.hub's fork-validation check with a no-op.
        # NOTE(review): presumably a workaround for GitHub API failures; confirm.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load(
            "ultralytics/yolov5", "custom", path=model_path, force_reload=force_reload
        )
        self.clip_model_name = clip_model_name or self.DEFAULT_CLIP_MODEL
        self.clip_model = None
        self.clip_processor = None
        # Copy of the last input image with boxes/labels drawn by predict().
        self.last_annotated_image = None

    def _load_clip(self) -> None:
        """Load the CLIP model and processor named by ``clip_model_name``."""
        name = getattr(self, "clip_model_name", None) or self.DEFAULT_CLIP_MODEL
        self.clip_model = CLIPModel.from_pretrained(name)
        self.clip_processor = CLIPProcessor.from_pretrained(name)
        self.clip_model.eval()

    def predict_clip(self, image, brand_names: Sequence[str]) -> Tuple[str, float]:
        """Predict the most probable brand for a single image using CLIP.

        :param image: One image accepted by ``CLIPProcessor`` (e.g. a PIL image).
        :param brand_names: Candidate labels scored against the image.
        :return: ``(best_brand, probability)``, where probability is the softmax
            score of ``best_brand`` across ``brand_names``.
        :raises ValueError: if ``brand_names`` is empty.
        """
        brand_names = list(brand_names)
        if not brand_names:
            raise ValueError("brand_names must contain at least one candidate")
        # The CLIP attributes were never initialised before, so this method
        # always raised AttributeError; load them lazily instead.
        if (
            getattr(self, "clip_model", None) is None
            or getattr(self, "clip_processor", None) is None
        ):
            self._load_clip()
        inputs = self.clip_processor(
            text=brand_names,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = self.clip_model(**inputs)
        # One row per image, one column per candidate; only row 0 is scored.
        probs = outputs.logits_per_image.softmax(dim=1)
        best_idx = int(probs[0].argmax())
        return brand_names[best_idx], probs[0, best_idx].item()

    def predict_text(self, image) -> str:
        """Return Tesseract OCR text for ``image`` (converted to grayscale).

        OCR is best-effort. Any exception is printed and swallowed, and ``""``
        is returned so that a single failing crop does not abort :meth:`predict`.
        """
        try:
            grayscale = image.convert('L')
            text = pytesseract.image_to_string(grayscale)
            return text.strip()
        except Exception as e:
            print(f"Error during text prediction: {e}")
            return ""

    def predict(self, image_path) -> List[Dict[str, Any]]:
        """Run YOLO inference on an image and enrich every detection.

        :param image_path: Path to the input image.
        :return: One dict per detection with keys ``category``, ``bbox``
            (``[xmin, ymin, xmax, ymax]``), ``confidence``, ``category_llm``,
            ``predicted_brand``, ``price``, ``details``, ``detected_text``,
            ``image_path`` and ``image_url``.

        Side effects: writes ``crop_<idx>.jpg`` into ``cropped_images_dir``,
        deletes crops left over from earlier calls with more detections,
        uploads each crop, and stores an annotated copy of the input image on
        ``self.last_annotated_image``.
        """
        results = self.model(image_path)
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        # Draw on a copy. Drawing on ``image`` itself leaked earlier boxes and
        # labels into later, overlapping crops that are sent to the LLM and OCR.
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)

        predictions = results.pandas().xyxy[0]  # Get predictions as pandas DataFrame
        print(f'YOLO predictions:\n\n{predictions}')

        output = []
        for idx, row in predictions.iterrows():
            detection = self._process_detection(image, idx, row)
            self._draw_detection(draw, detection["bbox"], detection["category_llm"])
            output.append(detection)

        self.last_annotated_image = annotated
        self._remove_stale_crops(cropped_images_dir, len(predictions))
        return output

    def _process_detection(self, image, idx, row) -> Dict[str, Any]:
        """Crop, save, upload, query the LLM and OCR a single YOLO detection row."""
        category = row['name']
        confidence = row['confidence']
        bbox = [row["xmin"], row["ymin"], row["xmax"], row["ymax"]]

        cropped_image = image.crop(tuple(bbox))
        cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg")
        cropped_image.save(cropped_image_path, "JPEG")

        # The LLM is given a public URL, so the crop is uploaded first.
        print('Uploading now to image url')
        image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
        print(f'Image URL received as{image_url}')

        result_llms = inference.get_name(image_url, category)
        detected_text = self.predict_text(cropped_image)
        print(f'Details:{detected_text}')
        print(f'Predicted brand: {result_llms["model"]}')

        return {
            "category": category,
            "bbox": bbox,
            "confidence": confidence,
            "category_llm": result_llms["brand"],
            "predicted_brand": result_llms["model"],
            "price": result_llms["price"],
            "details": result_llms["description"],
            "detected_text": detected_text,
            "image_path": cropped_image_path,
            "image_url": image_url,
        }

    @staticmethod
    def _draw_detection(draw, bbox, label) -> None:
        """Draw ``bbox`` and ``label`` in red on an ``ImageDraw`` canvas."""
        draw.rectangle(bbox, outline="red", width=3)
        # The label sits 10 px above the box (the old f-string had a stray ")").
        draw.text((bbox[0], bbox[1] - 10), str(label), fill="red")

    @staticmethod
    def _remove_stale_crops(directory, keep_count: int) -> None:
        """Delete ``crop_<n>.jpg`` files in ``directory`` whose ``n`` is not in ``range(keep_count)``.

        Files that do not parse as ``crop_<int>.jpg`` are left untouched.
        """
        valid_indices = set(range(keep_count))
        for filename in os.listdir(directory):
            if not (filename.startswith("crop_") and filename.endswith(".jpg")):
                continue
            try:
                file_idx = int(filename.split("_")[1].split(".")[0])
            except ValueError:
                continue  # Skip files that don't match the pattern
            if file_idx not in valid_indices:
                os.remove(os.path.join(directory, filename))
                print(f"Deleted excess file: {filename}")