|
import requests |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
|
|
from transformers import CLIPProcessor, CLIPModel |
|
|
|
model = CLIPModel.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32") |
|
processor = CLIPProcessor.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32") |
|
|
|
url = "https://d168r5mdg5gtkq.cloudfront.net/medpix/img/full/synpic9078.jpg" |
|
image = Image.open(requests.get(url, stream=True).raw) |
|
text = ["Chest X-Ray", "Brain MRI", "Abdominal CT Scan"] |
|
|
|
inputs = processor(text=text, images=image, return_tensors="pt", padding=True) |
|
probs = model(**inputs).logits_per_image.softmax(dim=1).squeeze() |
|
|
|
plt.subplots() |
|
plt.imshow(image) |
|
plt.title("".join([x[0] + ": " + x[1] + "\n" for x in zip(text, [format(prob, ".4%") for prob in probs])])) |
|
plt.axis("off") |
|
plt.tight_layout() |
|
plt.show() |
|
|