limitedonly41's picture
Update app.py
28c504c verified
raw
history blame
4.71 kB
import os
import torch
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from torchvision import transforms
import easyocr
from transformers import CLIPProcessor, CLIPModel
from huggingface_hub import hf_hub_download
# Access token (None when the variable is unset) passed when downloading the
# model checkpoint from the Hugging Face Hub.
HF_TOKEN = os.getenv("HF_TOKEN")
# Classifier instance; stays None until the first inference call builds it.
model = None
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Class labels, indexed by the classifier's output position. The list length
# sizes the final linear layer and each index maps a logit back to a name, so
# order and contents must match the trained checkpoint exactly. The leading
# ".ipynb_checkpoints" entry looks like a notebook artifact that was picked up
# as a class folder during training; it is kept on purpose so the layer size
# still matches the saved weights.
classes = [
    ".ipynb_checkpoints",
    "2-products-in-one-offer",
    "2-products-in-one-offer-+-coupon",
    "an-offer-with-a-coupon",
    "availability-of-additional-products",
    "offers-with-a-preliminary-promotional-price",
    "offers-with-an-additional-deposit-price",
    "offers-with-an-additional-shipping",
    "offers-with-dealtype-special_price",
    "offers-with-different-sizes",
    "offers-with-money_rebate",
    "offers-with-percentage_rebate",
    "offers-with-price-characteristic-(statt)",
    "offers-with-price-characterization-(uvp)",
    "offers-with-product-number-(sku)",
    "offers-with-reward",
    "offers-with-the-condition_-available-from-such-and-such-a-number",
    "offers-with-the-old-price-crossed-out",
    "regular",
    "scene-with-multiple-offers-+-uvp-price-for-each-offers",
    "several-products-in-one-offer-with-different-prices",
    "simple-offers",
    "stock-offers",
    "stocks",
    "travel-booklets",
    "with-a-product-without-a-price",
    "with-the-price-of-the-supplemental-bid",
]
# Pixel preprocessing for the image branch:
#   1. resize to a fixed 224x224,
#   2. convert to a float tensor in [0, 1],
#   3. rescale to [-1, 1] via (x - 0.5) / 0.5, with the single mean/std value
#      applied to every channel.
# NOTE(review): these constants differ from CLIP's usual per-channel
# normalization; presumably they match what the checkpoint was trained with —
# confirm before changing either side.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# OCR engine for English and German text found in the uploaded images;
# model_storage_directory points EasyOCR's model files at the working directory.
ocr_reader = easyocr.Reader(
    ["en", "de"],
    model_storage_directory="./",
)

# Tokenizer/processor for the same CLIP checkpoint the classifier is built on;
# used here only to turn the OCR text into token ids.
clip_processor = CLIPProcessor.from_pretrained(
    "openai/clip-vit-base-patch32"
)
# Custom CLIP-based Multimodal Classifier
class CLIPMultimodalClassifier(torch.nn.Module):
    """Linear classification head over averaged CLIP image and text embeddings.

    The image and the text are each encoded by a pretrained CLIP model. The two
    projected feature vectors are averaged and a single linear layer maps the
    result to class logits.

    Args:
        num_classes: Number of output logits produced by the head.
        model_name: Hugging Face model id (or local path) of the CLIP
            checkpoint whose encoders are used. Defaults to
            "openai/clip-vit-base-patch32".
    """

    def __init__(self, num_classes, model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained(model_name)
        # Both modalities are projected into CLIP's shared embedding space, so
        # the head's input width is the projection dimension.
        self.fc = torch.nn.Linear(self.clip_model.config.projection_dim, num_classes)

    def forward(self, images, texts):
        """Return class logits for a batch of image/text pairs.

        Args:
            images: Pixel tensor passed to ``get_image_features``
                (presumably (batch, 3, 224, 224) — confirm against caller).
            texts: Token-id tensor passed to ``get_text_features``
                (presumably (batch, seq_len)).

        Returns:
            Logits tensor whose last dimension is ``num_classes``.
        """
        image_features = self.clip_model.get_image_features(images)
        text_features = self.clip_model.get_text_features(texts)
        # Equal-weight fusion of the two modalities.
        combined_features = (image_features + text_features) / 2
        return self.fc(combined_features)
# Model loading and inference
def _load_model(device):
    """Download the fine-tuned checkpoint and return the classifier on *device*.

    Args:
        device: torch.device the checkpoint is mapped to and the model moved onto.

    Returns:
        A CLIPMultimodalClassifier with ``len(classes)`` outputs, the downloaded
        state dict loaded, and eval mode enabled.
    """
    model_path = hf_hub_download(
        repo_id="limitedonly41/offers_26",
        filename="multi_train_best_model.pth",
        # `token` replaces the deprecated `use_auth_token` argument.
        token=HF_TOKEN,
    )
    classifier = CLIPMultimodalClassifier(len(classes)).to(device)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. The checkpoint is fetched from a remote repo, so it must be trusted;
    # consider weights_only=True where the installed torch version supports it.
    classifier.load_state_dict(torch.load(model_path, map_location=device))
    classifier.eval()
    return classifier


@spaces.GPU()
def run_inference(image, device):
    """Classify an uploaded offer image from its pixels and its OCR text.

    The classifier is built on the first call and cached in the module-level
    ``model``. Later calls reuse it, so ``device`` only decides where the model
    lives on that first call; the input tensors are always moved to ``device``.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``
            (presumably H x W x C uint8 — confirm). None when the form is
            submitted without an image.
        device: torch.device for the model (on first load) and input tensors.

    Returns:
        A two-line string with the predicted class name and the extracted OCR
        text. When ``image`` is None, a short message asking for an image.
    """
    global model

    # Gradio passes None when nothing was uploaded; PIL and EasyOCR would fail
    # on it, so return early before loading anything.
    if image is None:
        return "No image provided. Please upload an image."

    if model is None:
        model = _load_model(device)

    # Image branch: numpy -> PIL RGB -> resized/normalized tensor with a batch dim.
    image_pil = Image.fromarray(image).convert("RGB")
    image_tensor = image_transform(image_pil).unsqueeze(0).to(device)

    # Text branch: OCR the raw array (detail=0 -> text only) and join fragments.
    ocr_text = ocr_reader.readtext(image, detail=0)
    combined_text = " ".join(ocr_text)

    # Tokenize for CLIP; 77 tokens is the CLIP text encoder's context length.
    text_inputs = clip_processor(
        text=[combined_text],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77,
    ).to(device)

    with torch.no_grad():
        # Only input_ids are fed (no attention_mask).
        # NOTE(review): presumably matches how the model was trained — confirm.
        outputs = model(image_tensor, text_inputs["input_ids"])
    probabilities = torch.nn.functional.softmax(outputs, dim=1)
    predicted_class_idx = torch.argmax(probabilities, dim=1).item()

    predicted_class = classes[predicted_class_idx]
    return f"Predicted Class: {predicted_class}\nExtracted Text: {combined_text}"
# Gradio UI: one image input, one text output. The lambda binds the
# module-level `device` so the UI only has to supply the image.
iface = gr.Interface(
    fn=lambda image: run_inference(image, device),
    inputs=gr.Image(type="numpy"),  # handed to run_inference as a numpy array
    outputs="text",  # predicted class plus the OCR-extracted text
    title="Image Classification",
    description=(
        "Upload an image to get the predicted class from a CLIP-based "
        "classifier that combines the image with text extracted via OCR."
    ),
)

# Start the web server only when run as a script, so importing this module
# (e.g. from tests or other tooling) does not launch the app.
if __name__ == "__main__":
    iface.launch()