Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,26 +1,39 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import CLIPProcessor, CLIPModel
|
3 |
from PIL import Image
|
|
|
4 |
|
5 |
# Load the model and processor
|
6 |
model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
|
7 |
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
|
8 |
|
9 |
def classify_image(image):
|
10 |
-
#
|
11 |
-
|
|
|
|
|
|
|
|
|
12 |
# Perform the inference
|
13 |
outputs = model(**inputs)
|
|
|
14 |
# Postprocess the outputs
|
15 |
logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
16 |
probs = logits_per_image.softmax(dim=1) # we can use softmax to get probabilities
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
# Define Gradio interface
|
20 |
iface = gr.Interface(
|
21 |
fn=classify_image,
|
22 |
inputs=gr.Image(type="pil"),
|
23 |
-
outputs="
|
24 |
title="Geolocal StreetCLIP Classification",
|
25 |
description="Upload an image to classify using Geolocal StreetCLIP"
|
26 |
)
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import CLIPProcessor, CLIPModel
|
3 |
from PIL import Image
|
4 |
+
import torch
|
5 |
|
6 |
# Load the model and processor
|
7 |
model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
|
8 |
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
|
9 |
|
10 |
def classify_image(image):
|
11 |
+
# Example labels for classification
|
12 |
+
labels = ["a photo of a cat", "a photo of a dog", "a photo of a car", "a photo of a tree"]
|
13 |
+
|
14 |
+
# Preprocess the image and text
|
15 |
+
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
|
16 |
+
|
17 |
# Perform the inference
|
18 |
outputs = model(**inputs)
|
19 |
+
|
20 |
# Postprocess the outputs
|
21 |
logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
22 |
probs = logits_per_image.softmax(dim=1) # we can use softmax to get probabilities
|
23 |
+
|
24 |
+
# Convert the probabilities to a list
|
25 |
+
probs_list = probs.tolist()[0]
|
26 |
+
|
27 |
+
# Create a dictionary of labels and probabilities
|
28 |
+
result = {label: prob for label, prob in zip(labels, probs_list)}
|
29 |
+
|
30 |
+
return result
|
31 |
|
32 |
# Define Gradio interface
|
33 |
iface = gr.Interface(
|
34 |
fn=classify_image,
|
35 |
inputs=gr.Image(type="pil"),
|
36 |
+
outputs="label",
|
37 |
title="Geolocal StreetCLIP Classification",
|
38 |
description="Upload an image to classify using Geolocal StreetCLIP"
|
39 |
)
|