DGurgurov committed
Commit: b1f6a67
1 Parent(s): c901e48

Update app.py

Files changed (1):
  1. app.py +41 -32
app.py CHANGED
@@ -1,51 +1,60 @@
import gradio as gr
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
- from torchvision.transforms import functional as F
from PIL import Image

- # Load tokenizer and model
processor = AutoProcessor.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")
model = AutoModelForZeroShotImageClassification.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")

- # Define label mappings
- labels = set(dataset['train']['label'])
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}

- # Function to preprocess image
def preprocess_image(image):
-     image = Image.fromarray(image)  # Convert numpy array to PIL Image
-     image = image.convert("RGB")  # Ensure image is RGB (some images might be grayscale)
-     image = image.resize((224, 224))  # Resize image to match CLIP model input size
-     image = F.to_tensor(image)  # Convert PIL Image to PyTorch tensor
-     image = F.normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize image
-     return image
-
- # Function to predict using CLIP model
- def predict(image):
-     # Preprocess image
    image = preprocess_image(image)
-
-     # Prepare input for the model
-     inputs = processor(images=image.unsqueeze(0), labels=labels, return_tensors="pt")
-
-     # Perform inference
-     outputs = model(**inputs)
-
    # Get predicted label
-     logits_per_image = outputs.logits_per_image
-     predicted_class = labels[torch.argmax(logits_per_image, dim=-1)]
-
-     return predicted_class

- # Define Gradio interface
iface = gr.Interface(
-     fn=predict,
-     inputs=gr.Image(shape=(224, 224)),
-     outputs=gr.Textbox(),
    title="Animal Classifier",
-     description="CLIP-ViT model fine-tuned on Oxford Pets dataset to classify animals."
)

- # Launch the Gradio app
iface.launch()
 
import gradio as gr
+ import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
+ from datasets import load_dataset

+ # Load your fine-tuned model and dataset
processor = AutoProcessor.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")
model = AutoModelForZeroShotImageClassification.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")

+ # Load dataset to get labels
+ dataset = load_dataset("pcuenq/oxford-pets")  # Adjust dataset loading as per your setup
+
+ labels = list(set(dataset['train']['label']))
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}

+ # Define transformations for input images (using CLIP's normalization statistics)
+ transform = Compose([
+     Resize((224, 224)),
+     CenterCrop(224),
+     ToTensor(),
+     Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+               std=[0.26862954, 0.26130258, 0.27577711])
+ ])
+
+ # Function to preprocess the input image (Gradio passes a numpy array by default)
def preprocess_image(image):
+     image = Image.fromarray(image).convert("RGB")  # numpy array -> RGB PIL Image
+     image = transform(image)
+     return image.unsqueeze(0)  # add a batch dimension
+
+ # Function to classify image using CLIP model
+ def classify_image(image):
+     # Preprocess the image
    image = preprocess_image(image)
+
+     # Tokenize the candidate labels for CLIP's text encoder
+     text_inputs = processor(text=labels, return_tensors="pt", padding=True)
+
+     # Run inference
+     with torch.no_grad():
+         outputs = model(pixel_values=image, **text_inputs)
+
    # Get predicted label
+     predicted_label_id = outputs.logits_per_image.argmax(dim=-1).item()
+     predicted_label = id2label[predicted_label_id]
+
+     return predicted_label

+ # Gradio interface
iface = gr.Interface(
+     fn=classify_image,
+     inputs=gr.Image(label="Upload a picture of an animal"),
+     outputs=gr.Textbox(label="Predicted Animal"),
    title="Animal Classifier",
+     description="CLIP-based model fine-tuned on the Oxford Pets dataset to classify animals.",
)

+ # Launch the Gradio interface
iface.launch()
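
A quick way to sanity-check the handler before pushing the Space is to call classify_image directly, bypassing the UI. A minimal sketch, assuming a local test image ("cat.jpg" is a placeholder, not a file in this repo):

# Local smoke test for the classify_image handler.
import numpy as np

test_image = np.array(Image.open("cat.jpg").convert("RGB"))  # mimic gr.Image's default numpy input
print(classify_image(test_image))  # prints one of the dataset's label strings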
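Since the checkpoint loads through AutoModelForZeroShotImageClassification, the same prediction can likely also be made with the transformers zero-shot-image-classification pipeline, which bundles image preprocessing and label tokenization; a sketch under that assumption, reusing the placeholder path:

from transformers import pipeline

# The pipeline handles preprocessing, label tokenization, and softmax scoring.
classifier = pipeline("zero-shot-image-classification",
                      model="DGurgurov/clip-vit-base-patch32-oxford-pets")
scores = classifier("cat.jpg", candidate_labels=labels)  # results sorted by score
print(scores[0]["label"], scores[0]["score"])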