import os
import torch
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from torchvision import transforms
import easyocr
from transformers import CLIPProcessor, CLIPModel

HF_TOKEN = os.environ.get("HF_TOKEN")

# Class labels; the order must match the label indices used during training.
# ('.ipynb_checkpoints' is likely an artifact of the training directory listing,
# kept here so the indices stay aligned with the trained classification head.)
classes = [
    '.ipynb_checkpoints',
    '2-products-in-one-offer',
    '2-products-in-one-offer-+-coupon',
    'an-offer-with-a-coupon',
    'availability-of-additional-products',
    'offers-with-a-preliminary-promotional-price',
    'offers-with-an-additional-deposit-price',
    'offers-with-an-additional-shipping',
    'offers-with-dealtype-special_price',
    'offers-with-different-sizes',
    'offers-with-money_rebate',
    'offers-with-percentage_rebate',
    'offers-with-price-characteristic-(statt)',
    'offers-with-price-characterization-(uvp)',
    'offers-with-product-number-(sku)',
    'offers-with-reward',
    'offers-with-the-condition_-available-from-such-and-such-a-number',
    'offers-with-the-old-price-crossed-out',
    'regular',
    'scene-with-multiple-offers-+-uvp-price-for-each-offers',
    'several-products-in-one-offer-with-different-prices',
    'simple-offers',
    'stock-offers',
    'stocks',
    'travel-booklets',
    'with-a-product-without-a-price',
    'with-the-price-of-the-supplemental-bid',
]


# Custom CLIP-based multimodal classifier: averages CLIP image and text
# embeddings and feeds the result to a linear classification head.
class CLIPMultimodalClassifier(torch.nn.Module):
    def __init__(self, num_classes):
        super(CLIPMultimodalClassifier, self).__init__()
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.fc = torch.nn.Linear(self.clip_model.config.projection_dim, num_classes)

    def forward(self, images, texts):
        image_features = self.clip_model.get_image_features(images)
        text_features = self.clip_model.get_text_features(texts)
        combined_features = (image_features + text_features) / 2
        logits = self.fc(combined_features)
        return logits


# Image preprocessing (resize and normalize)
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# EasyOCR reader for text extraction (supports English and German)
ocr_reader = easyocr.Reader(['en', 'de'], model_storage_directory="./")

num_classes = len(classes)
model = CLIPMultimodalClassifier(num_classes)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


# Inference function
@spaces.GPU()
def run_inference(image, model, clip_processor):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Download the fine-tuned checkpoint from the Hub and load it into the model
    # (note: this runs on every call, so the weights are re-loaded per request)
    model_path = hf_hub_download(
        repo_id="limitedonly41/offers_26",
        filename="continued_training_model_5.pth",
        use_auth_token=HF_TOKEN,
    )
    model = model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Preprocess image
    image_pil = Image.fromarray(image).convert("RGB")
    image_tensor = image_transform(image_pil).unsqueeze(0).to(device)

    # Extract text using EasyOCR and join the results into one string
    ocr_text = ocr_reader.readtext(image, detail=0)
    combined_text = " ".join(ocr_text)

    # Preprocess text for CLIP
    text_inputs = clip_processor(
        text=[combined_text],  # Text in a list
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77,
    ).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(image_tensor, text_inputs["input_ids"])
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()

    # Return results
    predicted_class = classes[predicted_class_idx]
    return f"Predicted Class: {predicted_class}\nExtracted Text: {combined_text}"


# Create a Gradio interface
iface = gr.Interface(
    fn=lambda image: run_inference(image, model, clip_processor),
    inputs=gr.Image(type="numpy"),  # Image is passed to run_inference as a numpy array
    outputs="text",  # Output is text (predicted class and extracted OCR text)
    title="Image Classification",
    description="Upload an offer image to get the predicted class from the CLIP-based multimodal classifier.",
)

# Launch the Gradio app
iface.launch()
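
# Optional local sanity check (a minimal sketch, not part of the Space itself):
# bypasses the Gradio UI and calls run_inference directly on a local file.
# Assumes a test image named "sample_offer.jpg" next to this script; uncomment to use.
# if __name__ == "__main__":
#     sample = np.array(Image.open("sample_offer.jpg").convert("RGB"))
#     print(run_inference(sample, model, clip_processor))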