import streamlit as st
from transformers import CLIPModel, CLIPProcessor
import torch
from PIL import Image


# Loading the CLIP weights is slow, so cache the processor/model pair across Streamlit reruns.
@st.cache_resource
def load_clip(model_size='large'):
    if model_size == 'base':
        model_name = 'openai/clip-vit-base-patch32'
    elif model_size == 'large':
        model_name = 'openai/clip-vit-large-patch14'
    else:
        raise ValueError(f"model_size must be 'base' or 'large', got {model_size!r}")

    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)

    return processor, model


def inference_clip(processor, model, options, image):
    # Build a batch that pairs the image with every candidate text label.
    inputs = processor(text=options, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image has shape (1, num_options); softmax turns it into probabilities.
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    max_prob_idx = torch.argmax(probs, dim=1).item()
    max_prob_option = options[max_prob_idx]
    max_prob = probs[0, max_prob_idx].item()
    return max_prob_option, max_prob


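# A minimal standalone sketch of the two helpers above, outside the Streamlit UI.
# The image path and class names below are placeholders; adjust them to your data.
#
#   processor, model = load_clip(model_size='base')
#   img = Image.open("example.jpg")
#   label, prob = inference_clip(processor, model,
#                                ["a photo of a cat", "a photo of a dog"], img)
#   print(label, prob)
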
# Load the (cached) processor and model once per session.
clip_processor, clip_model = load_clip(model_size='large')

picture_file = st.file_uploader("Picture:", type=["jpg", "jpeg", "png"])
image = None
if picture_file is not None:
    image = Image.open(picture_file).convert("RGB")
    st.image(image, caption='Uploaded image of the damage', use_column_width=True)

options = st.text_input(label="Please enter the classes (comma-separated)", value="")
options = [option.strip() for option in options.split(",") if option.strip()]

if st.button("Compute"):
    if image is None or not options:
        st.warning("Please upload an image and enter at least one class.")
    else:
        result, probability = inference_clip(
            processor=clip_processor, model=clip_model, options=options, image=image
        )
        st.write(f"{result} ({probability:.1%})")
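
# To try the app, save this script (e.g. as app.py) and launch it with Streamlit:
#   streamlit run app.py
# Then upload a picture, type the candidate classes separated by commas, and
# click "Compute" to see CLIP's best-matching class and its probability.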