from transformers import OwlViTProcessor, OwlViTForObjectDetection
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import warnings
import torch
import os
import io

# settings
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
warnings.filterwarnings('ignore')
st.set_page_config()
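

# Thin wrapper around one OWL-ViT detection request: holds a single image path,
# the list of text queries, and the confidence threshold used when drawing boxes.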
class owl_vit:
    def __init__(self, image_path, text, threshold):
        self.image_path = image_path
        self.text = text
        self.threshold = threshold
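
    # Run the full pipeline: preprocess the image and text queries, run the
    # model, and convert the raw outputs into pixel-space boxes, scores and labels.
    # Note: processor.post_process is the API of older transformers releases;
    # newer versions may expose this as post_process_object_detection instead.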
    def process(self, processor, model):
        image = Image.open(self.image_path)
        if len(image.split()) == 1:
            image = image.convert("RGB")
        inputs = processor(text=[self.text], images=[image], return_tensors="pt")
        outputs = model(**inputs)
        target_sizes = torch.tensor([[image.height, image.width]])
        self.results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
        self.image = image
        return self.result_image()
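
    # Draw every detection above the threshold onto a matplotlib figure and
    # return the rendered plot as a PIL image for Streamlit to display.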
    def result_image(self):
        boxes, scores, labels = self.results[0]["boxes"], self.results[0]["scores"], self.results[0]["labels"]
        plt.figure()  # start from a fresh figure so repeated runs don't accumulate old boxes
        plt.imshow(self.image)
        ax = plt.gca()
        for box, score, label in zip(boxes, scores, labels):
            if score >= self.threshold:
                box = box.detach().numpy()
                color = list(mcolors.CSS4_COLORS.keys())[label]
                ax.add_patch(plt.Rectangle(box[:2], box[2] - box[0], box[3] - box[1], fill=False, color=color, linewidth=3))
                ax.text(box[0], box[1], f"{self.text[label]}: {round(score.item(), 2)}", fontsize=15, color=color)
        plt.tight_layout()
        img_buf = io.BytesIO()
        plt.savefig(img_buf, format='png')
        plt.close()  # free the figure now that it has been rasterised
        image = Image.open(img_buf)
        return image
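

# Download (or load from the local cache) the pretrained OWL-ViT processor and
# detection model from the Hugging Face Hub.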
def load_model():
    with st.spinner('Getting Neurons in Order ...'):
        processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
        model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
    return processor, model
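

# Render the annotated detection image in the Streamlit page.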
def show_detects(image):
    st.title("Results")
    st.image(image, use_column_width=True, caption="Object Detection Results", clamp=True)
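

# Handle an uploaded file: persist it to the images/ folder, run detection with
# the user's text queries and threshold, show the result, and prune the folder
# so it never grows past 1000 saved uploads.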
def process(upload, text, threshold):
    # save the upload to a file
    filetype = upload.name.split('.')[-1]
    name = len(os.listdir("images")) + 1
    file_path = os.path.join('images', f'{name}.{filetype}')
    with open(file_path, "wb") as f:
        f.write(upload.getbuffer())
    # predict detections and show results
    detector = owl_vit(file_path, text, threshold)
    results = detector.process(processor, model)  # processor/model are module-level globals set in __main__
    show_detects(results)
    # clean up - if over 1000 images in the folder, delete the oldest one
    if len(os.listdir("images")) > 1000:
        oldest = min(os.listdir("images"), key=lambda f: os.path.getctime(os.path.join("images", f)))
        os.remove(os.path.join("images", oldest))
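

# Build the Streamlit page: splash image, model description, a runnable example,
# and the upload/threshold/vocabulary controls for user images.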
def main(processor, model):
    # splash image
    st.image(os.path.join('refs', 'baseball_labeled.png'), use_column_width=True)
    # title and project description
    st.title("OWL-ViT")
    st.markdown("**OWL-ViT** is a zero-shot text-conditioned object detection model. OWL-ViT uses CLIP as its multi-modal \
        backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. \
        To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a \
        lightweight classification and box head to each transformer output token. Open-vocabulary classification \
        is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained \
        from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification \
        and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image \
        can be used to perform zero-shot text-conditioned object detection.", unsafe_allow_html=True)
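
    # In this app the "class names" are simply the strings the user enters; the
    # label indices returned by post_process index into that same query list
    # (see self.text[label] in result_image above).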
    # example
    if st.button("Run the Example Image/Text"):
        with st.spinner('Detecting Objects and Comparing Vocab...'):
            info = owl_vit(os.path.join('refs', 'baseball.jpg'), ["batter", "umpire", "catcher"], 0.50)
            results = info.process(processor, model)
            show_detects(results)
    if st.button("Clear Example"):
        st.markdown("")
    # upload controls
    col1, col2 = st.columns(2)
    threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.1)
    with col1:
        upload = st.file_uploader('Image:', type=['jpg', 'jpeg', 'png'])
    with col2:
        text = st.text_area('Objects to Detect: (comma separated)', "batter, umpire, catcher")
        text = [x.strip() for x in text.split(',')]
    # process the uploaded image
    if upload is not None and text is not None:
        filetype = upload.name.split('.')[-1]
        if filetype in ['jpg', 'jpeg', 'png']:
            with st.spinner('Detecting and Counting Single Image...'):
                process(upload, text, threshold)
        else:
            st.warning('Unsupported file type.')
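

# Entry point: load the model once per script run, then draw the page.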
if __name__ == '__main__':
    processor, model = load_model()
    main(processor, model)