DGurgurov commited on
Commit
ae3733a
1 Parent(s): ea824f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -25
app.py CHANGED
@@ -1,33 +1,19 @@
1
- import gradio as gr
2
- from transformers import CLIPProcessor, CLIPModel
3
- from PIL import Image
4
- import torch
5
 
6
- # Load the model and processor
7
- model_id = "DGurgurov/clip-vit-base-patch32-oxford-pets"
8
- model = CLIPModel.from_pretrained(model_id)
9
- processor_id = "openai/clip-vit-base-patch32"
10
- processor = CLIPProcessor.from_pretrained(processor_id)
11
 
12
- # Define the inference function
13
  def predict(image):
14
- # Preprocess the image
15
- image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL Image
16
- inputs = processor(images=image, return_tensors="pt")
17
-
18
- # Perform inference
19
- outputs = model(**inputs)
20
- logits_per_image = outputs.logits_per_image
21
- probs = torch.nn.functional.softmax(logits_per_image, dim=1)
22
-
23
- # Get top 5 class predictions
24
- labels = processor.tokenizer.convert_ids_to_tokens(outputs.logits_per_image.argmax(dim=-1))
25
- top_5 = labels[torch.topk(probs[0], 5).indices]
26
-
27
- return {f"Class {i}": label for i, label in enumerate(top_5)}
28
 
29
  # Define Gradio interface
30
- image = gr.components.Image(type="pil")
31
  label = gr.components.Label(num_top_classes=5)
32
 
33
  interface = gr.Interface(
 
1
+ from transformers import pipeline
 
 
 
2
 
3
+ # Define the pipeline for image classification
4
+ pipe = pipeline("image-classification", model="DGurgurov/clip-vit-base-patch32-oxford-pets")
 
 
 
5
 
6
+ # Define the predict function using the pipeline
7
  def predict(image):
8
+ # Perform inference using the pipeline
9
+ results = pipe(image)
10
+ return {f"Class {i}": result['label'] for i, result in enumerate(results)}
11
+
12
+ # Now you can use this predict function in your Gradio interface
13
+ import gradio as gr
 
 
 
 
 
 
 
 
14
 
15
  # Define Gradio interface
16
+ image = gr.components.Image()
17
  label = gr.components.Label(num_top_classes=5)
18
 
19
  interface = gr.Interface(