# Captured from a Hugging Face Space file viewer (Space status: Sleeping).
# File size: 5,894 bytes; commit ids shown: e2e8ffc, bca9cda.
import torch
from pathlib import Path
from transformers import CLIPProcessor, CLIPModel
from PIL import Image, ImageDraw
import pytesseract
import requests
import os
from llm import inference, upload_image
from fastapi.responses import FileResponse, JSONResponse
import re
from io import BytesIO
# Folder receiving one JPEG per YOLO detection ("crop_<idx>.jpg"). It is
# created at import time so YOLOModel.predict can save crops straight into it.
cropped_images_dir = "cropped_images"
Path(cropped_images_dir).mkdir(parents=True, exist_ok=True)
# YOLOv5 detector whose detections are cropped, described by an LLM and OCR'd.
class YOLOModel:
    """Wrap a YOLOv5 ``torch.hub`` model and post-process its detections.

    For every detected object, :meth:`predict` crops the region, saves it under
    ``cropped_images_dir``, uploads it with
    ``upload_image.upload_image_to_imgbb`` so ``inference.get_name`` can be
    queried with the resulting URL, and runs Tesseract OCR on the crop.
    """

    # Hugging Face checkpoint loaded on the first call to predict_clip().
    # NOTE(review): the original code never initialised CLIP at all; this
    # default checkpoint is an assumption -- override it on the class or an
    # instance if a different model was intended.
    clip_model_name = "openai/clip-vit-base-patch32"

    def __init__(self, model_path="yolov5s.pt", force_reload=True):
        """Load a custom YOLOv5 model through ``torch.hub``.

        :param model_path: Weights path passed to
            ``torch.hub.load("ultralytics/yolov5", "custom", path=...)``.
        :param force_reload: Forwarded to ``torch.hub.load``. ``True`` (the
            historical behaviour) re-fetches the hub repo on every
            construction; pass ``False`` to reuse the local hub cache.
        """
        # Bypass torch.hub's forked-repo validation (a private torch helper).
        # Presumably a workaround for that check failing when loading
        # ultralytics/yolov5 -- TODO confirm it is still needed.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load(
            "ultralytics/yolov5", "custom", path=model_path, force_reload=force_reload
        )
        # CLIP is only needed by predict_clip(); it is loaded lazily there.
        self.clip_processor = None
        self.clip_model = None

    def _ensure_clip_loaded(self):
        """Load the CLIP processor/model from ``clip_model_name`` if not already set."""
        if getattr(self, "clip_processor", None) is None:
            self.clip_processor = CLIPProcessor.from_pretrained(self.clip_model_name)
        if getattr(self, "clip_model", None) is None:
            self.clip_model = CLIPModel.from_pretrained(self.clip_model_name)

    def predict_clip(self, image, brand_names):
        """Return the most probable brand for ``image`` using CLIP.

        :param image: A single image accepted by ``CLIPProcessor``.
        :param brand_names: Non-empty sequence of candidate brand strings.
        :return: ``(best_brand, probability)``, where ``probability`` is the
            softmax score of ``best_brand`` over ``brand_names``.
        :raises ValueError: If ``brand_names`` is empty.
        """
        if not brand_names:
            raise ValueError("brand_names must contain at least one candidate")
        self._ensure_clip_loaded()
        inputs = self.clip_processor(
            text=brand_names,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        # logits_per_image has one row per image; a single image is scored.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        best_idx = int(probs.argmax().item())
        return brand_names[best_idx], probs[best_idx].item()

    def predict_text(self, image):
        """Run Tesseract OCR on a grayscale copy of ``image``; return stripped text.

        OCR is best-effort: any exception raised by the conversion or the OCR
        call is printed and an empty string is returned instead.
        """
        try:
            grayscale = image.convert('L')
            text = pytesseract.image_to_string(grayscale)
        except Exception as e:  # deliberate best-effort, see docstring
            print(f"Error during text prediction: {e}")
            return ""
        return text.strip()

    def predict(self, image_path):
        """Run YOLO inference on an image and describe every detection.

        :param image_path: Path to the input image.
        :return: One dict per detection with keys ``category``, ``bbox``
            (``[xmin, ymin, xmax, ymax]``), ``confidence``, ``category_llm``,
            ``predicted_brand``, ``price``, ``details``, ``detected_text``,
            ``image_path`` and ``image_url``.
        """
        results = self.model(image_path)
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        # Annotate a copy: drawing on `image` itself would bake the boxes and
        # labels of earlier detections into the crops of overlapping later ones.
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)

        predictions = results.pandas().xyxy[0]  # Get predictions as pandas DataFrame
        print(f'YOLO predictions:\n\n{predictions}')

        output = []
        for idx, row in predictions.iterrows():
            bbox = [row["xmin"], row["ymin"], row["xmax"], row["ymax"]]
            record = self._describe_detection(
                image, idx, row['name'], row['confidence'], bbox
            )
            output.append(record)
            draw.rectangle(bbox, outline="red", width=3)
            draw.text((bbox[0], bbox[1] - 10), f'{record["category_llm"]}', fill="red")
        # NOTE(review): `annotated` is neither returned nor saved; callers that
        # want the visualisation need an API change to receive it.

        self._remove_stale_crops(len(predictions))
        return output

    def _describe_detection(self, image, idx, category, confidence, bbox):
        """Crop one detection, upload it, query the LLM, OCR it; return its record."""
        cropped_image = image.crop(tuple(bbox))
        cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg")
        cropped_image.save(cropped_image_path, "JPEG")

        # The LLM is queried with a URL, so the crop is uploaded first.
        print('Uploading now to image url')
        image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
        print(f'Image URL received as{image_url}')
        result_llms = inference.get_name(image_url, category)

        detected_text = self.predict_text(cropped_image)
        print(f'Details:{detected_text}')
        print(f'Predicted brand: {result_llms["model"]}')

        return {
            "category": category,
            "bbox": bbox,
            "confidence": confidence,
            "category_llm": result_llms["brand"],
            "predicted_brand": result_llms["model"],
            "price": result_llms["price"],
            "details": result_llms["description"],
            "detected_text": detected_text,
            "image_path": cropped_image_path,
            "image_url": image_url,
        }

    @staticmethod
    def _remove_stale_crops(count):
        """Delete ``crop_<i>.jpg`` files in ``cropped_images_dir`` whose i is not in [0, count)."""
        for filename in os.listdir(cropped_images_dir):
            if not (filename.startswith("crop_") and filename.endswith(".jpg")):
                continue
            try:
                file_idx = int(filename[len("crop_"):-len(".jpg")])
            except ValueError:
                continue  # resembles the pattern but has no integer index
            if not 0 <= file_idx < count:
                os.remove(os.path.join(cropped_images_dir, filename))
                print(f"Deleted excess file: {filename}")
|