guidel committed
Commit 7ae7542
1 Parent(s): cf6e303

Create app.py

Files changed (1):
  app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

#################################
#### FUNCTIONS

def load_clip(model_size='large'):
    """Load a pretrained CLIP model and its processor from the Hugging Face Hub."""
    if model_size == 'base':
        model_name = 'openai/clip-vit-base-patch32'
    elif model_size == 'large':
        model_name = 'openai/clip-vit-large-patch14'
    else:
        raise ValueError(f"Unknown model_size: {model_size}")

    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)

    return processor, model

def inference_clip(options, image, processor, model):
    """Return the label from `options` that CLIP scores highest for `image`."""
    inputs = processor(text=options, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    logits_per_image = outputs.logits_per_image  # image-text similarity scores
    probs = logits_per_image.softmax(dim=1)      # softmax over the candidate labels

    max_prob_idx = torch.argmax(probs, dim=1).item()
    max_prob_option = options[max_prob_idx]
    return max_prob_option

#################################
#### LAYOUT

clip_processor, clip_model = load_clip(model_size='large')

image = None
picture_file = st.file_uploader("Picture of the damage:", type=["jpg", "jpeg", "png"])
if picture_file is not None:
    image = Image.open(picture_file)
    st.image(image, caption='Uploaded image', use_column_width=True)

# Comma-separated candidate classes, e.g. "scratch, dent, broken glass"
options = st.text_input(label="Please enter the classes (comma-separated)", value="")
options = [opt.strip() for opt in options.split(",") if opt.strip()]

# Button to launch compute
if st.button("Compute"):
    if image is None or not options:
        st.warning("Please upload an image and enter at least one class.")
    else:
        result = inference_clip(options=options, image=image,
                                processor=clip_processor, model=clip_model)
        st.write(result)
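
One note on the script above: Streamlit re-executes the whole file on every widget interaction, so load_clip reloads the large CLIP checkpoint on each rerun. A minimal sketch of one way to avoid that, assuming a Streamlit version that ships st.cache_resource (the name load_clip_cached is illustrative, not part of this commit):

    import streamlit as st
    from transformers import CLIPModel, CLIPProcessor

    @st.cache_resource  # cache the loaded model across Streamlit reruns
    def load_clip_cached(model_size='large'):
        # Hypothetical cached variant of load_clip from the commit above
        model_name = ('openai/clip-vit-base-patch32' if model_size == 'base'
                      else 'openai/clip-vit-large-patch14')
        processor = CLIPProcessor.from_pretrained(model_name)
        model = CLIPModel.from_pretrained(model_name)
        return processor, model

    clip_processor, clip_model = load_clip_cached(model_size='large')

To try the app locally, install streamlit, torch, transformers, and Pillow, then run `streamlit run app.py`.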