from transformers import OwlViTProcessor, OwlViTForObjectDetection
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import warnings
import torch
import os
import io
# settings
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
warnings.filterwarnings('ignore')
st.set_page_config()
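
# Thin wrapper around one OWL-ViT detection run: holds an image path, the list of
# text queries, and a confidence threshold, then draws the predicted boxes onto the image.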
class owl_vit:
    def __init__(self, image_path, text, threshold):
        self.image_path = image_path
        self.text = text
        self.threshold = threshold

    def process(self, processor, model):
        # run the text queries and the image through OWL-ViT, then post-process
        # the raw outputs into boxes/scores/labels at the original image size
        image = Image.open(self.image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(text=[self.text], images=[image], return_tensors="pt")
        outputs = model(**inputs)
        target_sizes = torch.tensor([[image.height, image.width]])
        self.results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
        self.image = image
        return self.result_image()
    def result_image(self):
        # draw every box above the confidence threshold, then render the matplotlib
        # figure into an in-memory PNG and return it as a PIL image
        boxes, scores, labels = self.results[0]["boxes"], self.results[0]["scores"], self.results[0]["labels"]
        plt.figure()
        plt.imshow(self.image)
        ax = plt.gca()
        for box, score, label in zip(boxes, scores, labels):
            if score >= self.threshold:
                box = box.detach().numpy()
                color = list(mcolors.CSS4_COLORS.keys())[label]
                ax.add_patch(plt.Rectangle(box[:2], box[2] - box[0], box[3] - box[1], fill=False, color=color, linewidth=3))
                ax.text(box[0], box[1], f"{self.text[label]}: {round(score.item(), 2)}", fontsize=15, color=color)
        plt.tight_layout()
        img_buf = io.BytesIO()
        plt.savefig(img_buf, format='png')
        plt.close()
        image = Image.open(img_buf)
        return image
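
# Download (or fetch from the local cache) the OWL-ViT base-patch16 processor and
# detection model. If the installed Streamlit version provides it, decorating this
# function with @st.cache_resource would keep the weights from being reloaded on every rerun.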
def load_model():
    with st.spinner('Getting Neurons in Order ...'):
        processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
        model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
    return processor, model
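
# Display the annotated detection image on the page.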
def show_detects(image):
    st.title("Results")
    st.image(image, use_column_width=True, caption="Object Detection Results", clamp=True)
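
# Persist the uploaded file under images/, run detection on it, show the results,
# and keep the folder from growing without bound.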
def process(upload, text, threshold):
    # save the upload to a file
    os.makedirs("images", exist_ok=True)
    filetype = upload.name.split('.')[-1]
    name = len(os.listdir("images")) + 1
    file_path = os.path.join('images', f'{name}.{filetype}')
    with open(file_path, "wb") as f:
        f.write(upload.getbuffer())
    # predict detections and show results
    detector = owl_vit(file_path, text, threshold)
    results = detector.process(processor, model)
    show_detects(results)
    # clean up - if there are more than 1000 images in the folder, delete the oldest one
    if len(os.listdir("images")) > 1000:
        oldest = min(os.listdir("images"), key=lambda f: os.path.getctime(os.path.join("images", f)))
        os.remove(os.path.join("images", oldest))
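
# Build the Streamlit page: splash image, model description, a one-click example,
# and an upload form that feeds process() above.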
def main(processor, model):
    # splash image
    st.image(os.path.join('refs', 'baseball_labeled.png'), use_column_width=True)
    # title and project description
    st.title("OWL-ViT")
    st.markdown("**OWL-ViT** is a zero-shot text-conditioned object detection model. OWL-ViT uses CLIP as its multi-modal \
        backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. \
        To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a \
        lightweight classification and box head to each transformer output token. Open-vocabulary classification \
        is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained \
        from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification \
        and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image \
        can be used to perform zero-shot text-conditioned object detection.", unsafe_allow_html=True)
    # example
    if st.button("Run the Example Image/Text"):
        with st.spinner('Detecting Objects and Comparing Vocab...'):
            info = owl_vit(os.path.join('refs', 'baseball.jpg'), ["batter", "umpire", "catcher"], 0.50)
            results = info.process(processor, model)
            show_detects(results)
    if st.button("Clear Example"):
        st.markdown("")

    # upload
    col1, col2 = st.columns(2)
    threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.1)
    with col1:
        upload = st.file_uploader('Image:', type=['jpg', 'jpeg', 'png'])
    with col2:
        text = st.text_area('Objects to Detect: (comma separated)', "batter, umpire, catcher")
        text = [x.strip() for x in text.split(',')]
    # process the upload
    if upload is not None and text is not None:
        filetype = upload.name.split('.')[-1]
        if filetype in ['jpg', 'jpeg', 'png']:
            with st.spinner('Detecting Objects in Uploaded Image...'):
                process(upload, text, threshold)
        else:
            st.warning('Unsupported file type.')

if __name__ == '__main__':
    processor, model = load_model()
    main(processor, model)