Spaces:
Sleeping
Sleeping
| import torch | |
| from pathlib import Path | |
| from transformers import CLIPProcessor, CLIPModel | |
| from PIL import Image, ImageDraw | |
| import pytesseract | |
| import requests | |
| import os | |
| from llm import inference, upload_image | |
| from fastapi.responses import FileResponse, JSONResponse | |
| import re | |
| from io import BytesIO | |
# Output folder for the per-detection crop files (crop_<idx>.jpg); the path is
# relative to the current working directory and is created at import time.
cropped_images_dir = "cropped_images"
Path(cropped_images_dir).mkdir(parents=True, exist_ok=True)
class YOLOModel:
    """
    YOLOv5 detector that enriches each detection with an LLM lookup and OCR text.

    Optional CLIP zero-shot brand scoring is available through ``predict_clip``.
    """

    # Hugging Face checkpoint loaded by predict_clip. The previous code never chose
    # (or loaded) a CLIP model at all. NOTE(review): confirm this is the intended one.
    CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

    # Populated lazily on the first predict_clip call so YOLO-only use never pays
    # the CLIP download/load cost.
    clip_model = None
    clip_processor = None

    def __init__(self, model_path="yolov5s.pt", force_reload=True):
        """
        Load the YOLOv5 model through ``torch.hub``.

        :param model_path: Weights path handed to the ``ultralytics/yolov5`` hub
            ``custom`` entry point. NOTE(review): whether missing weights are
            downloaded automatically is decided by that repo's loader, not here.
        :param force_reload: Forwarded to ``torch.hub.load``. ``True`` (the previous
            hard-coded behaviour) re-fetches the hub repository on every construction.
        """
        # Globally disables torch.hub's forked-repo check. NOTE(review): presumably a
        # workaround for that check's GitHub API call failing (e.g. rate limiting).
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load(
            "ultralytics/yolov5", "custom", path=model_path, force_reload=force_reload
        )

    def _load_clip(self):
        """Load the CLIP processor and model named by ``CLIP_MODEL_NAME`` onto this instance."""
        self.clip_processor = CLIPProcessor.from_pretrained(self.CLIP_MODEL_NAME)
        self.clip_model = CLIPModel.from_pretrained(self.CLIP_MODEL_NAME)

    def predict_clip(self, image, brand_names):
        """
        Predict the most probable brand for ``image`` using CLIP zero-shot scoring.

        :param image: Image accepted by the CLIP processor (e.g. a PIL image).
        :param brand_names: Non-empty list of candidate brand strings.
        :return: ``(best_brand, probability)``, where probability is the softmax
            score of ``best_brand`` among all ``brand_names``.
        :raises ValueError: If ``brand_names`` is empty.
        """
        if not brand_names:
            raise ValueError("brand_names must contain at least one candidate")
        if self.clip_model is None or self.clip_processor is None:
            # __init__ never loads CLIP; without this every call raised AttributeError.
            self._load_clip()
        inputs = self.clip_processor(
            text=brand_names,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        # Softmax across the text candidates. Only the first image's row is scored,
        # so argmax is taken per row instead of over the flattened tensor.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        best_idx = int(probs.argmax())
        return brand_names[best_idx], probs[best_idx].item()

    def predict_text(self, image):
        """
        OCR the given image with Tesseract, best-effort.

        :param image: PIL image to read text from.
        :return: The recognised text with surrounding whitespace stripped, or ``""``
            if anything fails (the error is printed, never raised).
        """
        try:
            # OCR runs on a single-channel ("L" mode) copy of the image.
            grayscale = image.convert('L')
            text = pytesseract.image_to_string(grayscale)
        except Exception as e:
            # Deliberately best-effort: a failed OCR must not abort the detection loop.
            print(f"Error during text prediction: {e}")
            return ""
        return text.strip()

    @staticmethod
    def _remove_stale_crops(directory, keep_indices):
        """Delete ``crop_<idx>.jpg`` files in ``directory`` whose idx is not in ``keep_indices``."""
        for filename in os.listdir(directory):
            if not (filename.startswith("crop_") and filename.endswith(".jpg")):
                continue
            try:
                file_idx = int(filename.split("_")[1].split(".")[0])
            except ValueError:
                # Only resembles the pattern (e.g. crop_abc.jpg); leave it alone.
                continue
            if file_idx in keep_indices:
                continue
            try:
                os.remove(os.path.join(directory, filename))
            except FileNotFoundError:
                # Already gone, e.g. removed by another call sharing the directory.
                continue
            print(f"Deleted excess file: {filename}")

    def predict(self, image_path):
        """
        Run YOLO inference on an image and enrich every detection.

        For each detected box the region is cropped and saved as
        ``<cropped_images_dir>/crop_<idx>.jpg``. The crop is uploaded with
        ``upload_image.upload_image_to_imgbb`` to obtain a URL, that URL and the
        YOLO class name are sent to ``inference.get_name``, and the crop is OCR'd.
        The box and the LLM-reported brand are drawn on an annotated copy of the
        image. Crop files left over from earlier calls are deleted afterwards.

        NOTE(review): crop file names are fixed per index in a shared directory, so
        concurrent calls can overwrite or delete each other's crops.

        :param image_path: Path to the input image.
        :return: One dict per detection with keys ``category``, ``bbox``
            (``[xmin, ymin, xmax, ymax]``), ``confidence``, ``category_llm``,
            ``predicted_brand``, ``price``, ``details``, ``detected_text``,
            ``image_path`` and ``image_url``.
        """
        results = self.model(image_path)
        # Close the source file immediately; convert() returns an independent copy.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        # Draw on a separate copy. Drawing on `image` itself leaked red boxes and
        # labels of earlier detections into later, overlapping crops sent to the
        # LLM and OCR. NOTE(review): the annotated copy is currently never saved
        # or returned.
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)
        predictions = results.pandas().xyxy[0]  # one row per detection
        print(f'YOLO predictions:\n\n{predictions}')

        output = []
        written_indices = set()
        for idx, row in predictions.iterrows():
            category = row['name']
            confidence = float(row['confidence'])
            bbox = [
                float(row["xmin"]),
                float(row["ymin"]),
                float(row["xmax"]),
                float(row["ymax"]),
            ]

            # Crop the detected region from the clean (unannotated) image.
            cropped_image = image.crop(tuple(bbox))
            cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg")
            cropped_image.save(cropped_image_path, "JPEG")
            written_indices.add(idx)

            # The LLM takes an image URL, so the crop is uploaded first.
            print('Uploading now to image url')
            image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
            print(f'Image URL received as{image_url}')
            result_llms = inference.get_name(image_url, category)
            detected_text = self.predict_text(cropped_image)
            print(f'Details:{detected_text}')
            print(f'Predicted brand: {result_llms["model"]}')

            draw.rectangle(bbox, outline="red", width=3)
            # Clamp so labels of boxes touching the top edge stay on the canvas.
            draw.text(
                (bbox[0], max(0, bbox[1] - 10)),
                f'{result_llms["brand"]}',
                fill="red",
            )

            output.append({
                "category": category,
                "bbox": bbox,
                "confidence": confidence,
                "category_llm": result_llms["brand"],
                "predicted_brand": result_llms["model"],
                "price": result_llms["price"],
                "details": result_llms["description"],
                "detected_text": detected_text,
                "image_path": cropped_image_path,
                "image_url": image_url,
            })

        # Drop crops from earlier calls that produced more detections than this one.
        self._remove_stale_crops(cropped_images_dir, written_indices)
        return output