File size: 5,328 Bytes
4c728e9
882a71b
4c728e9
882a71b
 
 
 
 
4c728e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
882a71b
4c728e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import warnings
import torch
import os
import io


# Settings applied once at import: expose only CUDA device index 1 to this
# process, silence all Python warnings, and apply Streamlit's default page
# configuration (it must run before any other Streamlit call).
# NOTE(review): nothing visible in this file moves the model to a GPU.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})
warnings.filterwarnings(action='ignore')
st.set_page_config()


class owl_vit:
    """Run OWL-ViT zero-shot detection on one image and render the detections.

    Constructor args:
        image_path: path to the image file to analyse.
        text: list of text queries, e.g. ``["batter", "umpire"]``; the label
            indices returned by the model index into this list.
        threshold: minimum score a detection needs to be drawn.

    After :meth:`process` has run, ``self.image`` holds the RGB input image
    and ``self.results`` the post-processed detections.
    """

    def __init__(self, image_path, text, threshold):
        self.image_path = image_path
        self.text = text
        self.threshold = threshold

    def process(self, processor, model):
        """Detect the ``self.text`` queries in the image and annotate it.

        Args:
            processor: an ``OwlViTProcessor`` (text + image preprocessing and
                post-processing).
            model: an ``OwlViTForObjectDetection`` model.

        Returns:
            A ``PIL.Image`` of the input with detections scoring at least
            ``self.threshold`` drawn on it.
        """
        # Normalise every mode (L, P, LA, RGBA, CMYK, ...) to 3-channel RGB.
        # Previously only single-band images were converted, so e.g. RGBA PNGs
        # reached the processor with 4 channels. convert() returns a loaded
        # copy, so the underlying file can be closed right away.
        with Image.open(self.image_path) as img:
            image = img.convert("RGB")
        inputs = processor(text=[self.text], images=[image], return_tensors="pt")
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        # One (height, width) row per image in the batch (a batch of one here).
        target_sizes = torch.tensor([[image.height, image.width]])
        self.results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
        self.image = image
        return self.result_image()

    def result_image(self):
        """Draw detections scoring at least ``self.threshold`` on ``self.image``.

        Must be called after :meth:`process`, which sets ``self.results`` and
        ``self.image``.

        Returns:
            A ``PIL.Image`` decoded from a PNG rendering of the annotated plot.
        """
        detections = self.results[0]
        boxes, scores, labels = detections["boxes"], detections["scores"], detections["labels"]
        # Loop-invariant: the label index selects a colour from this list.
        color_names = list(mcolors.CSS4_COLORS.keys())
        # Use a dedicated figure per call. Drawing on pyplot's implicit current
        # figure made repeated runs in one process stack new boxes onto the
        # previous plot, and the figure was never released.
        fig, ax = plt.subplots()
        try:
            ax.imshow(self.image)
            for box, score, label in zip(boxes, scores, labels):
                if score >= self.threshold:
                    label = int(label)
                    box = box.detach().numpy()
                    color = color_names[label]
                    # Boxes are (x0, y0, x1, y1); Rectangle wants origin, width, height.
                    ax.add_patch(plt.Rectangle(box[:2], box[2] - box[0], box[3] - box[1], fill=False, color=color, linewidth=3))
                    ax.text(box[0], box[1], f"{self.text[label]}: {round(score.item(), 2)}", fontsize=15, color=color)
            fig.tight_layout()
            img_buf = io.BytesIO()
            fig.savefig(img_buf, format='png')
        finally:
            plt.close(fig)
        img_buf.seek(0)
        return Image.open(img_buf)


def _cache_resource(func):
    """Cache ``func``'s return value across Streamlit reruns.

    Uses ``st.cache_resource`` when this Streamlit version provides it, and
    otherwise falls back to ``st.cache(allow_output_mutation=True)`` so the
    returned objects are not hashed on each access.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_cache_resource
def _load_pretrained():
    """Load the OWL-ViT processor and detection model (cached)."""
    processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
    model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
    return processor, model


def load_model():
    """Return the ``(processor, model)`` pair for ``google/owlvit-base-patch16``.

    Streamlit re-executes the script on every widget interaction. The actual
    loading is therefore cached, so later reruns reuse the same objects
    instead of reloading the checkpoint each time.
    """
    with st.spinner('Getting Neurons in Order ...'):
        return _load_pretrained()


def show_detects(image):
    """Render an annotated detection image beneath a "Results" heading.

    Args:
        image: the annotated image to display (anything ``st.image`` accepts).
    """
    heading, caption = "Results", "Object Detection Results"
    st.title(heading)
    st.image(image, caption=caption, use_column_width=True, clamp=True)


def process(upload, text, threshold):
    """Save an uploaded image, run detection on it and display the result.

    Args:
        upload: the uploaded file object returned by ``st.file_uploader``.
        text: list of text queries to detect.
        threshold: minimum score a detection needs to be drawn.

    Reads the module-level ``processor`` and ``model`` bound in the
    ``__main__`` guard. At most 1000 uploads are kept in ``images/``; the
    oldest is removed once the cap is exceeded.
    """
    import uuid

    image_dir = "images"
    max_saved = 1000

    # save upload to file
    # Create the folder on first use (open()/os.listdir() raised otherwise).
    os.makedirs(image_dir, exist_ok=True)
    # A random name cannot collide. The old "file count + 1" scheme reused
    # names once pruning held the count at the cap (and across concurrent
    # sessions), silently overwriting earlier uploads. Only the extension is
    # taken from the client-supplied name; splitext never returns a path
    # separator, so the file always lands inside image_dir.
    extension = os.path.splitext(upload.name)[1].lower()
    file_path = os.path.join(image_dir, f"{uuid.uuid4().hex}{extension}")
    with open(file_path, "wb") as f:
        f.write(upload.getbuffer())

    # predict detections and show results
    detector = owl_vit(file_path, text, threshold)
    results = detector.process(processor, model)
    show_detects(results)

    # clean up - if over the cap, delete the oldest file
    saved = [os.path.join(image_dir, name) for name in os.listdir(image_dir)]
    if len(saved) > max_saved:
        # Full paths are required. The previous code passed bare file names
        # to getctime, which resolved them against the CWD and raised.
        oldest = min(saved, key=os.path.getctime)
        os.remove(oldest)


def main(processor, model):
    """Render the OWL-ViT demo page: description, example run and upload form.

    Args:
        processor: an ``OwlViTProcessor``, used for the built-in example.
        model: an ``OwlViTForObjectDetection`` model, used for the example.

    Note: uploads are handled by ``process``, which reads the module-level
    ``processor``/``model`` rather than these arguments.
    """

    # splash image
    st.image(os.path.join('refs', 'baseball_labeled.png'), use_column_width=True)

    # title project descriptions
    st.title("OWL-ViT")
    st.markdown("**OWL-ViT** is a zero-shot text-conditioned object detection model. OWL-ViT uses CLIP as its multi-modal \
                backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. \
                To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a \
                lightweight classification and box head to each transformer output token. Open-vocabulary classification \
                is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained \
                from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification \
                and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image \
                can be used to perform zero-shot text-conditioned object detection.", unsafe_allow_html=True)

    # example
    if st.button("Run the Example Image/Text"):
        with st.spinner('Detecting Objects and Comparing Vocab...'):
            info = owl_vit(os.path.join('refs', 'baseball.jpg'), ["batter", "umpire", "catcher"], 0.50)
            results = info.process(processor, model)
            show_detects(results)
            # Clicking this triggers a rerun in which the example button above
            # reads False, so the example disappears. The body below never
            # actually executes.
            if st.button("Clear Example"):
                st.markdown("")

    # upload
    col1, col2 = st.columns(2)
    threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.1)
    with col1:
        upload = st.file_uploader('Image:', type=['jpg', 'jpeg', 'png'])
    with col2:
        text = st.text_area('Objects to Detect: (comma, separated)', "batter, umpire, catcher")
        # Drop empty entries (e.g. from a trailing comma) so no blank query
        # is sent to the model.
        text = [x.strip() for x in text.split(',') if x.strip()]

    # process
    if upload is not None:
        if not text:
            st.warning('Enter at least one object to detect.')
            return
        # Case-insensitive, so "photo.JPG" is accepted like "photo.jpg".
        filetype = upload.name.split('.')[-1].lower()
        if filetype in ['jpg', 'jpeg', 'png']:
            with st.spinner('Detecting and Counting Single Image...'):
                process(upload, text, threshold)
        else:
            st.warning('Unsupported file type.')


if __name__ == '__main__':
    # Bound at module scope on purpose: process() reads `processor` and
    # `model` as globals rather than receiving them as arguments.
    processor, model = load_model()
    main(processor=processor, model=model)