import os
import torch
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from torchvision import transforms
import easyocr
from transformers import CLIPProcessor, CLIPModel

HF_TOKEN = os.environ.get("HF_TOKEN")
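# HF_TOKEN is read from the environment (e.g., a Space secret) and is used below
# to authenticate the checkpoint download from the Hub.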
model = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classes = ['.ipynb_checkpoints', '2-products-in-one-offer',
           '2-products-in-one-offer-+-coupon',
           'an-offer-with-a-coupon',
           'availability-of-additional-products',
           'offers-with-a-preliminary-promotional-price',
           'offers-with-an-additional-deposit-price',
           'offers-with-an-additional-shipping',
           'offers-with-dealtype-special_price',
           'offers-with-different-sizes',
           'offers-with-money_rebate',
           'offers-with-percentage_rebate',
           'offers-with-price-characteristic-(statt)',
           'offers-with-price-characterization-(uvp)',
           'offers-with-product-number-(sku)',
           'offers-with-reward',
           'offers-with-the-condition_-available-from-such-and-such-a-number',
           'offers-with-the-old-price-crossed-out',
           'regular',
           'scene-with-multiple-offers-+-uvp-price-for-each-offers',
           'several-products-in-one-offer-with-different-prices',
           'simple-offers',
           'stock-offers',
           'stocks',
           'travel-booklets',
           'with-a-product-without-a-price',
           'with-the-price-of-the-supplemental-bid']
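# NOTE: the label order above must match the order used when the checkpoint was
# trained, since the predicted index is mapped back through this list
# ('.ipynb_checkpoints' is presumably an artifact of the training folder layout,
# kept here to preserve index alignment).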
# Image preprocessing (resize and normalize)
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
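# Normalization uses a simple (0.5, 0.5) scheme rather than CLIP's native image
# preprocessing; it is assumed to match how the checkpoint was fine-tuned.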
# EasyOCR reader for text extraction (English and German)
# ocr_reader = easyocr.Reader(['en', 'de'], gpu=False)
ocr_reader = easyocr.Reader(['en', 'de'], model_storage_directory="./")

# CLIP processor
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
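# The processor is only used to tokenize the OCR text; images go through the
# torchvision transform defined above instead of CLIP's image preprocessing.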
# Custom CLIP-based multimodal classifier
class CLIPMultimodalClassifier(torch.nn.Module):
    def __init__(self, num_classes):
        super(CLIPMultimodalClassifier, self).__init__()
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        # Linear head on top of CLIP's shared projection space
        self.fc = torch.nn.Linear(self.clip_model.config.projection_dim, num_classes)

    def forward(self, images, texts):
        # images: (batch, 3, 224, 224) float tensor; texts: (batch, 77) token ids
        image_features = self.clip_model.get_image_features(images)
        text_features = self.clip_model.get_text_features(texts)
        # Both encoders project into the same embedding space, so the two
        # modalities are fused by element-wise averaging before the linear head
        combined_features = (image_features + text_features) / 2
        logits = self.fc(combined_features)
        return logits
# Inference function
def run_inference(image, device):
    global model

    # Lazily download and load the fine-tuned checkpoint on the first call
    if model is None:
        model_path = hf_hub_download(repo_id="limitedonly41/offers_26",
                                     filename="multi_train_best_model.pth",
                                     token=HF_TOKEN)
        num_classes = len(classes)
        model = CLIPMultimodalClassifier(num_classes).to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

    # Convert the Gradio numpy image to PIL format and preprocess it
    image_pil = Image.fromarray(image).convert("RGB")
    image_tensor = image_transform(image_pil).unsqueeze(0).to(device)

    # Extract text from the image using EasyOCR
    ocr_text = ocr_reader.readtext(image, detail=0)
    combined_text = " ".join(ocr_text)

    # Tokenize the OCR text for CLIP
    text_inputs = clip_processor(
        text=[combined_text],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77
    ).to(device)

    # Predict the class
    with torch.no_grad():
        outputs = model(image_tensor, text_inputs["input_ids"])
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()

    # Return the predicted class and the OCR result
    predicted_class = classes[predicted_class_idx]
    return f"Predicted Class: {predicted_class}\nExtracted Text: {combined_text}"
# Create a Gradio interface
iface = gr.Interface(
    fn=lambda image: run_inference(image, device),
    inputs=gr.Image(type="numpy"),
    outputs="text",  # predicted class plus the extracted OCR text
    title="Image Classification",
    description="Upload an image to get the predicted class from the CLIP-based multimodal classifier."
)

# Launch the Gradio app
iface.launch()