File size: 5,894 Bytes
e2e8ffc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bca9cda
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
from pathlib import Path
from transformers import CLIPProcessor, CLIPModel
from PIL import Image, ImageDraw
import pytesseract
import requests
import os 
from llm import inference, upload_image
from fastapi.responses import FileResponse, JSONResponse

import re

from io import BytesIO

# Folder that receives the per-detection crops; it is created as a side effect
# of importing this module so later saves never hit a missing directory.
cropped_images_dir = "cropped_images"
Path(cropped_images_dir).mkdir(parents=True, exist_ok=True)

# Load YOLO model
class YOLOModel:
    """
    YOLOv5 detector combined with OCR and an LLM-based brand lookup.

    ``predict`` detects objects in an image, crops each detection, uploads the
    crop, queries the LLM helper for brand details and runs Tesseract OCR on
    the crop. ``predict_clip`` is an optional CLIP-based brand classifier that
    requires ``clip_processor`` and ``clip_model`` to be assigned by the caller.
    """

    def __init__(self, model_path="yolov5s.pt"):
        """
        Load custom YOLOv5 weights through ``torch.hub``.

        :param model_path: Path to the YOLOv5 weights file
        """
        # HACK: replaces torch.hub's private fork-validation hook with a stub
        # that always succeeds (presumably to avoid failures of the GitHub
        # lookup it performs -- TODO confirm this is still required).
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        # NOTE: force_reload=True refreshes the hub cache on every construction.
        self.model = torch.hub.load(
            "ultralytics/yolov5", "custom", path=model_path, force_reload=True
        )
        # CLIP is not loaded here; assign both attributes before calling
        # predict_clip().
        self.clip_processor = None
        self.clip_model = None

    def predict_clip(self, image, brand_names):
        """
        Predict the most probable brand using CLIP.

        :param image: Image passed to the CLIP processor
        :param brand_names: Non-empty sequence of candidate brand names
        :return: Tuple ``(best_brand_name, probability)``
        :raises AttributeError: If ``clip_processor``/``clip_model`` are not set
        :raises ValueError: If ``brand_names`` is empty
        """
        processor = getattr(self, "clip_processor", None)
        model = getattr(self, "clip_model", None)
        if processor is None or model is None:
            # Same exception type the missing attribute used to trigger, but
            # with an actionable message.
            raise AttributeError(
                "CLIP is not configured: assign `clip_processor` and "
                "`clip_model` on the YOLOModel instance before calling "
                "predict_clip()."
            )
        if len(brand_names) == 0:
            raise ValueError("brand_names must contain at least one candidate")

        inputs = processor(
            text=brand_names,
            images=image,
            return_tensors="pt",
            padding=True
        )
        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1)  # logits -> probabilities
        best_idx = probs.argmax().item()
        return brand_names[best_idx], probs[0, best_idx].item()

    def predict_text(self, image):
        """
        Run OCR on an image and return the recognised text.

        :param image: Image object exposing ``convert`` (e.g. a PIL image)
        :return: Stripped OCR text, or ``""`` if anything fails (best effort)
        """
        try:
            # Tesseract is fed a grayscale version of the image.
            grayscale = image.convert('L')
            text = pytesseract.image_to_string(grayscale)
            return text.strip()
        except Exception as e:
            # Deliberate best effort: OCR failure must not abort detection.
            print(f"Error during text prediction: {e}")
            return ""

    @staticmethod
    def _remove_stale_crops(directory, valid_indices):
        """
        Delete ``crop_<n>.jpg`` files in ``directory`` whose ``<n>`` is not in
        ``valid_indices``; files with a non-numeric index are left alone.
        """
        for filename in os.listdir(directory):
            if not (filename.startswith("crop_") and filename.endswith(".jpg")):
                continue
            try:
                file_idx = int(filename.split("_")[1].split(".")[0])
            except ValueError:
                # Not a crop file produced by predict(); skip it.
                continue
            if file_idx not in valid_indices:
                os.remove(os.path.join(directory, filename))
                print(f"Deleted excess file: {filename}")

    def predict(self, image_path):
        """
        Run YOLO inference on an image and enrich every detection.

        For each detection the region is cropped and saved as
        ``crop_<idx>.jpg`` in ``cropped_images_dir``, uploaded to obtain a URL,
        described by the LLM helper and scanned with OCR. Crop files left over
        from earlier runs that do not belong to this run are deleted afterwards.

        :param image_path: Path to the input image
        :return: List of dicts with keys ``category``, ``bbox``, ``confidence``,
                 ``category_llm``, ``predicted_brand``, ``price``, ``details``,
                 ``detected_text``, ``image_path`` and ``image_url``
        """
        results = self.model(image_path)
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        # Boxes are drawn on a copy so that crops taken for later detections
        # never contain outlines/labels painted for earlier ones.
        # NOTE(review): the annotated image is currently neither saved nor
        # returned.
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)
        predictions = results.pandas().xyxy[0]  # Get predictions as pandas DataFrame
        print(f'YOLO predictions:\n\n{predictions}')

        output = []
        written_indices = set()

        for idx, row in predictions.iterrows():
            category = row['name']
            # Plain floats keep the result JSON-serialisable.
            confidence = float(row['confidence'])
            bbox = [float(row[key]) for key in ("xmin", "ymin", "xmax", "ymax")]

            # Crop the detected region from the clean (un-annotated) image
            cropped_image = image.crop(tuple(bbox))
            cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg")
            cropped_image.save(cropped_image_path, "JPEG")
            written_indices.add(idx)

            # uploading to cloud for getting URL to pass into LLM
            print('Uploading now to image url')
            image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
            print(f'Image URL received as{image_url}')
            # inferencing llm for possible brands
            result_llms = inference.get_name(image_url, category)

            detected_text = self.predict_text(cropped_image)
            print(f'Details:{detected_text}')
            print(f'Predicted brand: {result_llms["model"]}')

            # Draw bounding box and label on the annotated copy
            draw.rectangle(bbox, outline="red", width=3)
            draw.text(
                (bbox[0], bbox[1] - 10),
                f'{result_llms["brand"]}',
                fill="red"
            )

            # Append result
            output.append({
                "category": category,
                "bbox": bbox,
                "confidence": confidence,
                "category_llm": result_llms["brand"],
                "predicted_brand": result_llms["model"],
                "price": result_llms["price"],
                "details": result_llms["description"],
                "detected_text": detected_text,
                "image_path": cropped_image_path,
                "image_url": image_url,
            })

        # Runs once, after all crops are written, and also when nothing was
        # detected, so crops from a previous run never linger.
        self._remove_stale_crops(cropped_images_dir, written_indices)

        return output
    #     return JSONResponse(
    #     content={
    #         "metadata": results,
    #         "cropped_image_urls": [
    #             f"/download_cropped_image/{idx}" for idx in range(len(file_responses))
    #         ],
    #     }
    # )
        # return {"metadata": results, "cropped_image_urls": file_responses}