mrsarthakgupta commited on
Commit
1fd3d1a
1 Parent(s): 34f6368

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
+ import torch
4
+ from PIL import Image
5
+
6
+ # Load pre-trained model and feature extractor
7
+ model_name = "google/vit-base-patch16-224"
8
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
9
+ model = AutoModelForImageClassification.from_pretrained(model_name)
10
+
11
+ def classify_image(image):
12
+ # Preprocess the image
13
+ inputs = feature_extractor(images=image, return_tensors="pt")
14
+
15
+ # Make prediction
16
+ with torch.no_grad():
17
+ outputs = model(**inputs)
18
+
19
+ # Get the predicted class
20
+ predicted_class_idx = outputs.logits.argmax(-1).item()
21
+ predicted_class = model.config.id2label[predicted_class_idx]
22
+
23
+ return predicted_class
24
+
25
+ # Create Gradio interface
26
+ iface = gr.Interface(
27
+ fn=classify_image,
28
+ inputs=gr.Image(type="pil"),
29
+ outputs=gr.Textbox(label="Predicted Class"),
30
+ title="Image Classification",
31
+ description="Upload an image to classify it using a pre-trained ViT model."
32
+ )
33
+
34
+ # Launch the app
35
+ iface.launch()