Spaces:
Sleeping
Sleeping
first commit
Browse files- Aptfile +5 -0
- Procfile +1 -0
- README.md +16 -13
- app.py +132 -0
- images/1.png +0 -0
- images/2.png +0 -0
- images/3.jpg +0 -0
- images/4.jpg +0 -0
- images/5.jpg +0 -0
- images/6.jpg +0 -0
- memory_test.py +58 -0
- plot.png +0 -0
- refs/baseball.jpg +0 -0
- refs/baseball_labeled.png +0 -0
- requirements.txt +136 -0
- setup.sh +10 -0
Aptfile
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
libsm6
|
2 |
+
libgl1
|
3 |
+
libxrender1
|
4 |
+
libfontconfig1
|
5 |
+
libice6
|
Procfile
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
web: sh setup.sh && streamlit run main.py
|
README.md
CHANGED
@@ -1,13 +1,16 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
1 |
+
# OwL-Vit Streamlit App
|
2 |
+
|
3 |
+
## Summary
|
4 |
+
An application that ultilizes OWL-Vit and Streamlit to understand text to detect objects in images.
|
5 |
+
|
6 |
+
Deployed and ready to be tested on :
|
7 |
+
|
8 |
+
[![Open in Streamlit](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://www.matthewciolino.com/)
|
9 |
+
|
10 |
+
## Example
|
11 |
+
|
12 |
+
![Baseball Field Picture](refs/baseball_labeled.png)
|
13 |
+
|
14 |
+
## Made Possible with Adapatation from:
|
15 |
+
#### Implementation -> https://huggingface.co/docs/transformers/model_doc/owlvit 🙏
|
16 |
+
#### Paper (Yolo) -> https://arxiv.org/abs/2205.06230
|
app.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
4 |
+
import warnings
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from io import BytesIO
|
8 |
+
import streamlit as st
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import io
|
11 |
+
import matplotlib.colors as mcolors
|
12 |
+
|
13 |
+
|
14 |
+
# setttings
|
15 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
|
16 |
+
warnings.filterwarnings('ignore')
|
17 |
+
st.set_page_config()
|
18 |
+
|
19 |
+
|
20 |
+
class owl_vit:
|
21 |
+
|
22 |
+
def __init__(self, image_path, text, threshold):
|
23 |
+
self.image_path = image_path
|
24 |
+
self.text = text
|
25 |
+
self.threshold = threshold
|
26 |
+
|
27 |
+
def process(self, processor, model):
|
28 |
+
image = Image.open(self.image_path)
|
29 |
+
if len(image.split()) == 1:
|
30 |
+
image = image.convert("RGB")
|
31 |
+
inputs = processor(text=[self.text], images=[image], return_tensors="pt")
|
32 |
+
outputs = model(**inputs)
|
33 |
+
target_sizes = torch.tensor([[image.height, image.width] for image in [image]])
|
34 |
+
self.results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
|
35 |
+
self.image = image
|
36 |
+
return self.result_image()
|
37 |
+
|
38 |
+
def result_image(self):
|
39 |
+
boxes, scores, labels = self.results[0]["boxes"], self.results[0]["scores"], self.results[0]["labels"]
|
40 |
+
plt.imshow(self.image)
|
41 |
+
ax = plt.gca()
|
42 |
+
for box, score, label in zip(boxes, scores, labels):
|
43 |
+
if score >= self.threshold:
|
44 |
+
box = box.detach().numpy()
|
45 |
+
color = list(mcolors.CSS4_COLORS.keys())[label]
|
46 |
+
ax.add_patch(plt.Rectangle(box[:2], box[2] - box[0], box[3] - box[1], fill=False, color=color, linewidth=3,))
|
47 |
+
ax.text(box[0], box[1], f"{self.text[label]}: {round(score.item(), 2)}", fontsize=15, color=color)
|
48 |
+
plt.tight_layout()
|
49 |
+
img_buf = io.BytesIO()
|
50 |
+
plt.savefig(img_buf, format='png')
|
51 |
+
image = Image.open(img_buf)
|
52 |
+
return image
|
53 |
+
|
54 |
+
|
55 |
+
def load_model():
|
56 |
+
with st.spinner('Getting Neruons in Order ...'):
|
57 |
+
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
|
58 |
+
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
|
59 |
+
return processor, model
|
60 |
+
|
61 |
+
|
62 |
+
def show_detects(image):
|
63 |
+
st.title("Results")
|
64 |
+
st.image(image, use_column_width=True, caption="Object Detection Results", clamp=True)
|
65 |
+
|
66 |
+
def process(upload, text, threshold):
|
67 |
+
|
68 |
+
# save upload to file
|
69 |
+
filetype = upload.name.split('.')[-1]
|
70 |
+
name = len(os.listdir("images")) + 1
|
71 |
+
file_path = os.path.join('images', f'{name}.{filetype}')
|
72 |
+
with open(file_path, "wb") as f:
|
73 |
+
f.write(upload.getbuffer())
|
74 |
+
|
75 |
+
# predict detections and show results
|
76 |
+
detector = owl_vit(file_path, text, threshold)
|
77 |
+
results = detector.process(processor, model)
|
78 |
+
show_detects(results)
|
79 |
+
|
80 |
+
# clean up - if over 1000 images in folder, delete oldest 1
|
81 |
+
if len(os.listdir("images")) > 1000:
|
82 |
+
oldest = min(os.listdir("images"), key=os.path.getctime)
|
83 |
+
os.remove(os.path.join("images", oldest))
|
84 |
+
|
85 |
+
|
86 |
+
def main(processor, model):
|
87 |
+
|
88 |
+
# splash image
|
89 |
+
st.image(os.path.join('refs', 'baseball_labeled.png'), use_column_width=True)
|
90 |
+
|
91 |
+
# title project descriptions
|
92 |
+
st.title("OWL-ViT")
|
93 |
+
st.markdown("**OWL-ViT** is a zero-shot text-conditioned object detection model. OWL-ViT uses CLIP as its multi-modal \
|
94 |
+
backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. \
|
95 |
+
To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a \
|
96 |
+
lightweight classification and box head to each transformer output token. Open-vocabulary classification \
|
97 |
+
is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained \
|
98 |
+
from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification \
|
99 |
+
and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image \
|
100 |
+
can be used to perform zero-shot text-conditioned object detection.", unsafe_allow_html=True)
|
101 |
+
|
102 |
+
# example
|
103 |
+
if st.button("Run the Example Image/Text"):
|
104 |
+
with st.spinner('Detecting Objects and Comparing Vocab...'):
|
105 |
+
info = owl_vit(os.path.join('refs', 'baseball.jpg'), ["batter", "umpire", "catcher"], 0.50)
|
106 |
+
results = info.process(processor, model)
|
107 |
+
show_detects(results)
|
108 |
+
if st.button("Clear Example"):
|
109 |
+
st.markdown("")
|
110 |
+
|
111 |
+
# upload
|
112 |
+
col1, col2 = st.columns(2)
|
113 |
+
threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.1)
|
114 |
+
with col1:
|
115 |
+
upload = st.file_uploader('Image:', type=['jpg', 'jpeg', 'png'])
|
116 |
+
with col2:
|
117 |
+
text = st.text_area('Objects to Detect: (comma, seperated)', "batter, umpire, catcher")
|
118 |
+
text = [x.strip() for x in text.split(',')]
|
119 |
+
|
120 |
+
# process
|
121 |
+
if upload is not None and text is not None:
|
122 |
+
filetype = upload.name.split('.')[-1]
|
123 |
+
if filetype in ['jpg', 'jpeg', 'png']:
|
124 |
+
with st.spinner('Detecting and Counting Single Image...'):
|
125 |
+
process(upload, text, threshold)
|
126 |
+
else:
|
127 |
+
st.warning('Unsupported file type.')
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == '__main__':
|
131 |
+
processor, model = load_model()
|
132 |
+
main(processor, model)
|
images/1.png
ADDED
images/2.png
ADDED
images/3.jpg
ADDED
images/4.jpg
ADDED
images/5.jpg
ADDED
images/6.jpg
ADDED
memory_test.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# USAGE: python run_psrecord.py <PID> --plot plot.png --log activity.txt
|
2 |
+
|
3 |
+
from psrecord.main import monitor
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
|
7 |
+
def main():
|
8 |
+
# copied from C:\Users\user\anaconda3\envs\tfod\Lib\site-packages\psrecord\main.py
|
9 |
+
parser = argparse.ArgumentParser(
|
10 |
+
description='Record CPU and memory usage for a process')
|
11 |
+
|
12 |
+
parser.add_argument('process_id_or_command', type=str,
|
13 |
+
help='the process id or command')
|
14 |
+
|
15 |
+
parser.add_argument('--log', type=str,
|
16 |
+
help='output the statistics to a file')
|
17 |
+
|
18 |
+
parser.add_argument('--plot', type=str,
|
19 |
+
help='output the statistics to a plot')
|
20 |
+
|
21 |
+
parser.add_argument('--duration', type=float,
|
22 |
+
help='how long to record for (in seconds). If not '
|
23 |
+
'specified, the recording is continuous until '
|
24 |
+
'the job exits.')
|
25 |
+
|
26 |
+
parser.add_argument('--interval', type=float,
|
27 |
+
help='how long to wait between each sample (in '
|
28 |
+
'seconds). By default the process is sampled '
|
29 |
+
'as often as possible.')
|
30 |
+
|
31 |
+
parser.add_argument('--include-children',
|
32 |
+
help='include sub-processes in statistics (results '
|
33 |
+
'in a slower maximum sampling rate).',
|
34 |
+
action='store_true')
|
35 |
+
|
36 |
+
args = parser.parse_args()
|
37 |
+
|
38 |
+
# Attach to process
|
39 |
+
try:
|
40 |
+
pid = int(args.process_id_or_command)
|
41 |
+
print("Attaching to process {0}".format(pid))
|
42 |
+
sprocess = None
|
43 |
+
except Exception:
|
44 |
+
import subprocess
|
45 |
+
command = args.process_id_or_command
|
46 |
+
print("Starting up command '{0}' and attaching to process"
|
47 |
+
.format(command))
|
48 |
+
sprocess = subprocess.Popen(command, shell=True)
|
49 |
+
pid = sprocess.pid
|
50 |
+
|
51 |
+
monitor(pid, logfile=args.log, plot=args.plot, duration=args.duration,
|
52 |
+
interval=args.interval, include_children=args.include_children)
|
53 |
+
|
54 |
+
if sprocess is not None:
|
55 |
+
sprocess.kill()
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
main()
|
plot.png
ADDED
refs/baseball.jpg
ADDED
refs/baseball_labeled.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
altair==4.1.0
|
2 |
+
anyio==3.5.0
|
3 |
+
argon2-cffi==21.3.0
|
4 |
+
argon2-cffi-bindings==21.2.0
|
5 |
+
asttokens==2.2.1
|
6 |
+
attrs==22.1.0
|
7 |
+
autopep8==1.6.0
|
8 |
+
backcall==0.2.0
|
9 |
+
backports.functools-lru-cache==1.6.4
|
10 |
+
beautifulsoup4==4.11.1
|
11 |
+
bleach==4.1.0
|
12 |
+
blinker==1.4
|
13 |
+
Bottleneck==1.3.5
|
14 |
+
brotlipy==0.7.0
|
15 |
+
cachetools==4.2.2
|
16 |
+
certifi==2022.9.24
|
17 |
+
cffi==1.15.1
|
18 |
+
charset-normalizer==2.0.4
|
19 |
+
click==8.0.4
|
20 |
+
colorama==0.4.5
|
21 |
+
comm==0.1.2
|
22 |
+
commonmark==0.9.1
|
23 |
+
contourpy==1.0.5
|
24 |
+
cryptography==38.0.1
|
25 |
+
cycler==0.11.0
|
26 |
+
debugpy==1.5.1
|
27 |
+
decorator==5.1.1
|
28 |
+
defusedxml==0.7.1
|
29 |
+
entrypoints==0.4
|
30 |
+
executing==1.2.0
|
31 |
+
fastjsonschema==2.16.2
|
32 |
+
filelock==3.6.0
|
33 |
+
flit_core==3.6.0
|
34 |
+
fonttools==4.25.0
|
35 |
+
future==0.18.2
|
36 |
+
gitdb==4.0.7
|
37 |
+
GitPython==3.1.18
|
38 |
+
huggingface-hub==0.10.1
|
39 |
+
idna==3.4
|
40 |
+
importlib-metadata==4.11.3
|
41 |
+
ipykernel==6.19.3
|
42 |
+
ipython==8.7.0
|
43 |
+
ipython-genutils==0.2.0
|
44 |
+
ipywidgets==7.6.5
|
45 |
+
jedi==0.18.2
|
46 |
+
Jinja2==3.1.2
|
47 |
+
jsonschema==4.16.0
|
48 |
+
jupyter_client==7.4.8
|
49 |
+
jupyter_core==5.1.0
|
50 |
+
jupyter-server==1.18.1
|
51 |
+
jupyterlab-pygments==0.1.2
|
52 |
+
jupyterlab-widgets==1.0.0
|
53 |
+
kiwisolver==1.4.2
|
54 |
+
lxml==4.9.1
|
55 |
+
MarkupSafe==2.1.1
|
56 |
+
matplotlib==3.6.2
|
57 |
+
matplotlib-inline==0.1.6
|
58 |
+
mistune==0.8.4
|
59 |
+
mkl-fft==1.3.1
|
60 |
+
mkl-random==1.2.2
|
61 |
+
mkl-service==2.4.0
|
62 |
+
munkres==1.1.4
|
63 |
+
nbclassic==0.4.8
|
64 |
+
nbclient==0.5.13
|
65 |
+
nbconvert==6.5.4
|
66 |
+
nbformat==5.7.0
|
67 |
+
nest-asyncio==1.5.6
|
68 |
+
notebook==6.5.2
|
69 |
+
notebook_shim==0.2.2
|
70 |
+
numexpr==2.8.4
|
71 |
+
numpy==1.22.3
|
72 |
+
packaging==21.3
|
73 |
+
pandas==1.5.2
|
74 |
+
pandocfilters==1.5.0
|
75 |
+
parso==0.8.3
|
76 |
+
pickleshare==0.7.5
|
77 |
+
Pillow==9.3.0
|
78 |
+
pip==22.3.1
|
79 |
+
platformdirs==2.6.0
|
80 |
+
ply==3.11
|
81 |
+
prometheus-client==0.14.1
|
82 |
+
prompt-toolkit==3.0.36
|
83 |
+
protobuf==3.20.1
|
84 |
+
psutil==5.9.0
|
85 |
+
pure-eval==0.2.2
|
86 |
+
pyarrow==8.0.0
|
87 |
+
pycodestyle==2.8.0
|
88 |
+
pycparser==2.21
|
89 |
+
pydeck==0.7.1
|
90 |
+
Pygments==2.13.0
|
91 |
+
Pympler==0.9
|
92 |
+
pyOpenSSL==22.0.0
|
93 |
+
pyparsing==3.0.9
|
94 |
+
PyQt5==5.15.7
|
95 |
+
PyQt5-sip==12.11.0
|
96 |
+
pyrsistent==0.18.0
|
97 |
+
PySocks==1.7.1
|
98 |
+
python-dateutil==2.8.2
|
99 |
+
pytz==2022.1
|
100 |
+
PyYAML==6.0
|
101 |
+
pyzmq==23.2.0
|
102 |
+
regex==2022.7.9
|
103 |
+
requests==2.28.1
|
104 |
+
rich==12.5.1
|
105 |
+
semver==2.13.0
|
106 |
+
Send2Trash==1.8.0
|
107 |
+
setuptools==65.5.0
|
108 |
+
sip==6.6.2
|
109 |
+
six==1.16.0
|
110 |
+
smmap==4.0.0
|
111 |
+
sniffio==1.2.0
|
112 |
+
soupsieve==2.3.2.post1
|
113 |
+
stack-data==0.6.2
|
114 |
+
streamlit==1.11.0
|
115 |
+
terminado==0.13.1
|
116 |
+
tinycss2==1.2.1
|
117 |
+
tokenizers==0.11.4
|
118 |
+
toml==0.10.2
|
119 |
+
toolz==0.12.0
|
120 |
+
torch==1.13.1
|
121 |
+
torchaudio==0.13.1
|
122 |
+
torchvision==0.14.1
|
123 |
+
tornado==6.2
|
124 |
+
tqdm==4.64.1
|
125 |
+
traitlets==5.8.0
|
126 |
+
transformers==4.24.0
|
127 |
+
typing_extensions==4.4.0
|
128 |
+
tzlocal==2.1
|
129 |
+
urllib3==1.26.13
|
130 |
+
validators==0.18.2
|
131 |
+
watchdog==2.1.6
|
132 |
+
wcwidth==0.2.5
|
133 |
+
webencodings==0.5.1
|
134 |
+
websocket-client==0.58.0
|
135 |
+
wheel==0.37.1
|
136 |
+
zipp==3.8.0
|
setup.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/home/appuser/venv/bin/python -m pip install --upgrade pip
|
2 |
+
|
3 |
+
mkdir -p ~/.streamlit/
|
4 |
+
echo "\
|
5 |
+
[server]\n\
|
6 |
+
headless = true\n\
|
7 |
+
port = $PORT\n\
|
8 |
+
enableCORS = false\n\
|
9 |
+
\n\
|
10 |
+
" > ~/.streamlit/config.toml
|