# --- Scrape residue from the hosting page (not part of the program) ---------
# Spaces: Running / Running
# File size: 3,622 Bytes
# 1366c30 a269b46 d27f40c 1366c30 d27f40c a269b46 d27f40c cffabcf 0f2db82 a269b46 cffabcf a269b46 cffabcf a269b46 cffabcf 0f2db82 cffabcf 1366c30 d5345e2 e64dbd8 aa31199 01973e8 e64dbd8 01973e8 e810c3c a269b46 e810c3c 01973e8 e64dbd8 0f2db82 e64dbd8 a269b46 e64dbd8 a269b46
# (commit-hash row and line-number gutter preserved above as comments so the
# file remains valid Python)
import os
import token
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from transformers import CLIPProcessor, AutoTokenizer
from medclip.modeling_hybrid_clip import FlaxHybridCLIP
@st.cache_resource
def load_model():
    """Load the MedCLIP hybrid CLIP model and its SciBERT tokenizer.

    Cached by Streamlit so the weights are fetched only once per session.

    Returns:
        tuple: (FlaxHybridCLIP model, AutoTokenizer for SciBERT).
    """
    clip_model = FlaxHybridCLIP.from_pretrained(
        "flax-community/medclip-roco", _do_init=True
    )
    scibert_tokenizer = AutoTokenizer.from_pretrained(
        'allenai/scibert_scivocab_uncased'
    )
    return clip_model, scibert_tokenizer
@st.cache_resource
def load_image_embeddings():
    """Load the precomputed ROCO image embeddings from the feature store.

    Cached by Streamlit so the HDF file is read only once per session.

    Returns:
        tuple: (np.ndarray of image file names, np.ndarray of stacked
        image-embedding vectors, row-aligned with the file names).
    """
    store = pd.read_hdf('feature_store/image_embeddings_large.hdf', key='emb')
    stacked_embeds = np.stack(store['image_embedding'])
    file_names = np.asarray(store['files'].tolist())
    return file_names, stacked_embeds
# --- Page configuration ------------------------------------------------------
# Default number of results shown and the directory holding the ROCO images.
k = 5
img_dir = './images'

st.sidebar.header("MedCLIP")
st.sidebar.image("./assets/logo.png", width=100)
st.sidebar.empty()
st.sidebar.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the
[Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
st.sidebar.markdown("Example queries:")
# One button per canned example query; pressing one pre-fills the search box.
# NOTE(review): the "π" prefix below looks like a mojibake'd emoji (cf. the
# garbled title glyph) — confirm the intended character before "fixing" it.
ex1_button = st.sidebar.button("π pathology")
ex2_button = st.sidebar.button("π ultrasound scans")
ex3_button = st.sidebar.button("π pancreatic carcinoma")
ex4_button = st.sidebar.button("π PET scan")
k_slider = st.sidebar.slider("Number of images", min_value=1, max_value=10, value=5)
st.sidebar.markdown("Kaushalya Madhawa, 2021")
st.title("MedCLIP π©Ί")
# Resolve which (if any) example button was pressed into a preset query.
# Buttons are checked in declaration order; the first pressed one wins,
# matching the original if/elif chain.
text_value = ''
for pressed, preset in (
    (ex1_button, 'pathology'),
    (ex2_button, 'ultrasound scans'),
    (ex3_button, 'pancreatic carcinoma'),
    (ex4_button, 'PET scan'),
):
    if pressed:
        text_value = preset
        break
# --- Query handling and similarity search ------------------------------------
image_list, image_embeddings = load_image_embeddings()
model, tokenizer = load_model()

# The text box is pre-filled with the example-button selection (if any).
query = st.text_input("Enter your query here:", value=text_value)
dot_prod = None
if len(query)==0:
    query = text_value
# NOTE(review): `or k_slider` is always truthy (slider min_value is 1), so the
# branch runs on every Streamlit rerun regardless of the Search button —
# presumably deliberate so results refresh when the slider moves; confirm.
if st.button("Search") or k_slider:
    if len(query)==0:
        st.write("Please enter a valid search query")
    else:
        with st.spinner(f"Searching ROCO test set for {query}..."):
            k = k_slider
            # Tokenize the query and embed it with the CLIP text tower.
            inputs = tokenizer(text=[query], return_tensors="jax", padding=True)
            # st.write(f"Query inputs: {inputs}")
            query_embedding = model.get_text_features(**inputs)
            query_embedding = np.asarray(query_embedding)
            # L2-normalize the query so the dot product below acts as a cosine
            # similarity (assumes the stored image embeddings are already
            # unit-norm — TODO confirm in the feature-store pipeline).
            query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
            dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
            # argsort is ascending, so [-k:] picks the k highest-similarity
            # images; they are then displayed lowest-to-highest of that top-k.
            topk_images = dot_prod.argsort()[-k:]
            matching_images = image_list[topk_images]
            # Reported score is 1 - similarity, i.e. a distance-like value.
            top_scores = 1. - dot_prod[topk_images]
            # Show each matching image with its file name and score.
            for img_path, score in zip(matching_images, top_scores):
                img = plt.imread(os.path.join(img_dir, img_path))
                st.image(img, width=300)
                st.write(f"{img_path} ({score:.2f})")