mrciolino commited on
Commit
4c728e9
·
1 Parent(s): 9718176

first commit

Browse files
Files changed (16) hide show
  1. Aptfile +5 -0
  2. Procfile +1 -0
  3. README.md +16 -13
  4. app.py +132 -0
  5. images/1.png +0 -0
  6. images/2.png +0 -0
  7. images/3.jpg +0 -0
  8. images/4.jpg +0 -0
  9. images/5.jpg +0 -0
  10. images/6.jpg +0 -0
  11. memory_test.py +58 -0
  12. plot.png +0 -0
  13. refs/baseball.jpg +0 -0
  14. refs/baseball_labeled.png +0 -0
  15. requirements.txt +136 -0
  16. setup.sh +10 -0
Aptfile ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ libsm6
2
+ libgl1
3
+ libxrender1
4
+ libfontconfig1
5
+ libice6
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
web: sh setup.sh && streamlit run app.py
README.md CHANGED
@@ -1,13 +1,16 @@
1
- ---
2
- title: Ppt Owl Vit
3
- emoji: 🏃
4
- colorFrom: gray
5
- colorTo: pink
6
- sdk: streamlit
7
- sdk_version: 1.15.2
8
- app_file: app.py
9
- pinned: false
10
- license: cc-by-4.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
1
+ # OWL-ViT Streamlit App
2
+
3
+ ## Summary
4
+ An application that utilizes OWL-ViT and Streamlit to detect objects in images from free-text descriptions.
5
+
6
+ Deployed and ready to be tested at:
7
+
8
+ [![Open in Streamlit](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://www.matthewciolino.com/)
9
+
10
+ ## Example
11
+
12
+ ![Baseball Field Picture](refs/baseball_labeled.png)
13
+
14
+ ## Made Possible with Adaptation from:
15
+ #### Implementation -> https://huggingface.co/docs/transformers/model_doc/owlvit 🙏
16
+ #### Paper (OWL-ViT) -> https://arxiv.org/abs/2205.06230
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection
4
+ import warnings
5
+ import numpy as np
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import streamlit as st
9
+ import matplotlib.pyplot as plt
10
+ import io
11
+ import matplotlib.colors as mcolors
12
+
13
+
14
+ # settings
15
+ os.environ['CUDA_VISIBLE_DEVICES'] = '1'
16
+ warnings.filterwarnings('ignore')
17
+ st.set_page_config()
18
+
19
+
20
class owl_vit:
    """Run OWL-ViT text-conditioned detection on one image and draw the hits.

    Attributes:
        image_path: Path (or file object) that ``PIL.Image.open`` can read.
        text: List of text queries; a detection's label indexes into it.
        threshold: Minimum score a detection needs in order to be drawn.
    """

    def __init__(self, image_path, text, threshold):
        self.image_path = image_path
        self.text = text
        self.threshold = threshold

    def process(self, processor, model):
        """Detect the queried objects and return the annotated picture.

        Args:
            processor: OWL-ViT processor used to encode the inputs and to
                post-process the raw model outputs.
            model: OWL-ViT detection model, called on the encoded inputs.

        Returns:
            PIL image of the input with boxes and labels drawn on it.
        """
        image = Image.open(self.image_path)
        # Normalise every non-RGB mode (grayscale, palette, RGBA, CMYK, ...)
        # so the model always receives three colour channels.
        if image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(text=[self.text], images=[image], return_tensors="pt")
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        target_sizes = torch.tensor([[image.height, image.width]])
        self.results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
        self.image = image
        return self.result_image()

    def result_image(self):
        """Draw every detection scoring at least ``self.threshold``.

        Reads ``self.results`` and ``self.image`` set by :meth:`process`.
        Boxes are used as ``(x0, y0, x1, y1)`` corner coordinates.

        Returns:
            PIL image decoded from the PNG rendering of the figure.
        """
        detections = self.results[0]
        boxes, scores, labels = detections["boxes"], detections["scores"], detections["labels"]
        palette = list(mcolors.CSS4_COLORS)

        # Use a private figure and always close it: drawing on pyplot's
        # implicit global figure makes boxes from earlier runs pile up and
        # leaks figures in a long-lived server process.
        fig, ax = plt.subplots()
        try:
            ax.imshow(self.image)
            for box, score, label in zip(boxes, scores, labels):
                if score < self.threshold:
                    continue
                x0, y0, x1, y1 = box.detach().cpu().numpy()
                label = int(label)
                # Wrap around so a long query list cannot index past the palette.
                color = palette[label % len(palette)]
                ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, color=color, linewidth=3))
                ax.text(x0, y0, f"{self.text[label]}: {round(score.item(), 2)}", fontsize=15, color=color)
            fig.tight_layout()
            img_buf = io.BytesIO()
            fig.savefig(img_buf, format='png')
        finally:
            plt.close(fig)

        img_buf.seek(0)
        return Image.open(img_buf)
53
+
54
+
55
def load_model():
    """Load the OWL-ViT processor and detection model behind a spinner.

    Returns:
        tuple: ``(processor, model)`` for the
        ``google/owlvit-base-patch16`` checkpoint.
    """
    checkpoint = "google/owlvit-base-patch16"
    with st.spinner('Getting Neruons in Order ...'):
        owl_processor = OwlViTProcessor.from_pretrained(checkpoint)
        owl_model = OwlViTForObjectDetection.from_pretrained(checkpoint)
    return owl_processor, owl_model
60
+
61
+
62
def show_detects(image):
    """Show the annotated detection image under a "Results" heading."""
    st.title("Results")
    st.image(
        image,
        caption="Object Detection Results",
        use_column_width=True,
        clamp=True,
    )
65
+
66
def process(upload, text, threshold):
    """Save an uploaded image, run detection on it, show it, and prune uploads.

    Args:
        upload: Uploaded file object exposing ``name`` and ``getbuffer()``.
        text: List of text queries handed to ``owl_vit``.
        threshold: Minimum score for a detection to be drawn.

    Note:
        Uses the module-level ``processor`` and ``model`` globals instead of
        taking them as parameters.
    """
    image_dir = "images"
    os.makedirs(image_dir, exist_ok=True)

    # save upload to file; step past names that are already taken so an
    # earlier upload is never overwritten after clean-up has removed files
    filetype = upload.name.split('.')[-1]
    name = len(os.listdir(image_dir)) + 1
    file_path = os.path.join(image_dir, f'{name}.{filetype}')
    while os.path.exists(file_path):
        name += 1
        file_path = os.path.join(image_dir, f'{name}.{filetype}')
    with open(file_path, "wb") as f:
        f.write(upload.getbuffer())

    # predict detections and show results
    detector = owl_vit(file_path, text, threshold)
    results = detector.process(processor, model)
    show_detects(results)

    # clean up - if over 1000 images in folder, delete oldest 1.
    # getctime needs the full path; bare names from listdir would be resolved
    # against the working directory and raise FileNotFoundError.
    stored = [os.path.join(image_dir, entry) for entry in os.listdir(image_dir)]
    if len(stored) > 1000:
        os.remove(min(stored, key=os.path.getctime))
84
+
85
+
86
def main(processor, model):
    """Render the Streamlit page: intro, bundled example, and upload form.

    Args:
        processor: OWL-ViT processor passed on to ``owl_vit.process``.
        model: OWL-ViT detection model passed on to ``owl_vit.process``.
    """
    # splash image
    st.image(os.path.join('refs', 'baseball_labeled.png'), use_column_width=True)

    # title project descriptions
    st.title("OWL-ViT")
    st.markdown(
        "**OWL-ViT** is a zero-shot text-conditioned object detection model. OWL-ViT uses CLIP as its multi-modal "
        "backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. "
        "To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a "
        "lightweight classification and box head to each transformer output token. Open-vocabulary classification "
        "is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained "
        "from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification "
        "and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image "
        "can be used to perform zero-shot text-conditioned object detection.",
        unsafe_allow_html=True,
    )

    # example
    if st.button("Run the Example Image/Text"):
        with st.spinner('Detecting Objects and Comparing Vocab...'):
            info = owl_vit(os.path.join('refs', 'baseball.jpg'), ["batter", "umpire", "catcher"], 0.50)
            results = info.process(processor, model)
            show_detects(results)
        # Pressing this triggers a rerun in which the example button is no
        # longer "pressed", so the example output disappears.
        if st.button("Clear Example"):
            st.markdown("")

    # upload
    col1, col2 = st.columns(2)
    threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.1)
    with col1:
        upload = st.file_uploader('Image:', type=['jpg', 'jpeg', 'png'])
    with col2:
        text = st.text_area('Objects to Detect: (comma, seperated)', "batter, umpire, catcher")
    # Drop blank entries (empty box, trailing or doubled commas) so the model
    # is never queried with an empty string.
    queries = [entry.strip() for entry in text.split(',') if entry.strip()]

    # process
    if upload is None:
        return
    if not queries:
        st.warning('Enter at least one object to detect.')
        return
    # Compare case-insensitively so names such as "photo.JPG" are accepted.
    filetype = upload.name.split('.')[-1].lower()
    if filetype in ['jpg', 'jpeg', 'png']:
        with st.spinner('Detecting and Counting Single Image...'):
            process(upload, queries, threshold)
    else:
        st.warning('Unsupported file type.')
128
+
129
+
130
if __name__ == "__main__":
    # Keep both names at module level: process() reads `processor` and
    # `model` as globals instead of receiving them as arguments.
    processor, model = load_model()
    main(processor, model)
images/1.png ADDED
images/2.png ADDED
images/3.jpg ADDED
images/4.jpg ADDED
images/5.jpg ADDED
images/6.jpg ADDED
memory_test.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # USAGE: python memory_test.py <PID> --plot plot.png --log activity.txt
2
+
3
+ from psrecord.main import monitor
4
+ import argparse
5
+
6
+
7
def main():
    """Record CPU and memory usage of a process via ``psrecord.main.monitor``.

    The positional argument is either a PID to attach to or, when it does
    not parse as an integer, a command line that is started through the
    shell and then monitored. A command started here is killed once
    monitoring ends, including when ``monitor`` raises.
    """
    import subprocess

    # copied from C:\Users\user\anaconda3\envs\tfod\Lib\site-packages\psrecord\main.py
    parser = argparse.ArgumentParser(
        description='Record CPU and memory usage for a process')

    parser.add_argument('process_id_or_command', type=str,
                        help='the process id or command')

    parser.add_argument('--log', type=str,
                        help='output the statistics to a file')

    parser.add_argument('--plot', type=str,
                        help='output the statistics to a plot')

    parser.add_argument('--duration', type=float,
                        help='how long to record for (in seconds). If not '
                             'specified, the recording is continuous until '
                             'the job exits.')

    parser.add_argument('--interval', type=float,
                        help='how long to wait between each sample (in '
                             'seconds). By default the process is sampled '
                             'as often as possible.')

    parser.add_argument('--include-children',
                        help='include sub-processes in statistics (results '
                             'in a slower maximum sampling rate).',
                        action='store_true')

    args = parser.parse_args()

    # Attach to process
    sprocess = None
    try:
        pid = int(args.process_id_or_command)
    except ValueError:
        # Not a PID: treat the argument as a command line to launch.
        command = args.process_id_or_command
        print("Starting up command '{0}' and attaching to process"
              .format(command))
        # shell=True is deliberate: the operator supplies a full command line.
        sprocess = subprocess.Popen(command, shell=True)
        pid = sprocess.pid
    else:
        print("Attaching to process {0}".format(pid))

    try:
        monitor(pid, logfile=args.log, plot=args.plot, duration=args.duration,
                interval=args.interval, include_children=args.include_children)
    finally:
        # Never leave a child we started running, even if monitoring failed.
        if sprocess is not None:
            sprocess.kill()
56
+
57
# Run the recorder only when executed as a script, not on import.
if __name__ == "__main__":
    main()
plot.png ADDED
refs/baseball.jpg ADDED
refs/baseball_labeled.png ADDED
requirements.txt ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==4.1.0
2
+ anyio==3.5.0
3
+ argon2-cffi==21.3.0
4
+ argon2-cffi-bindings==21.2.0
5
+ asttokens==2.2.1
6
+ attrs==22.1.0
7
+ autopep8==1.6.0
8
+ backcall==0.2.0
9
+ backports.functools-lru-cache==1.6.4
10
+ beautifulsoup4==4.11.1
11
+ bleach==4.1.0
12
+ blinker==1.4
13
+ Bottleneck==1.3.5
14
+ brotlipy==0.7.0
15
+ cachetools==4.2.2
16
+ certifi==2022.9.24
17
+ cffi==1.15.1
18
+ charset-normalizer==2.0.4
19
+ click==8.0.4
20
+ colorama==0.4.5
21
+ comm==0.1.2
22
+ commonmark==0.9.1
23
+ contourpy==1.0.5
24
+ cryptography==38.0.1
25
+ cycler==0.11.0
26
+ debugpy==1.5.1
27
+ decorator==5.1.1
28
+ defusedxml==0.7.1
29
+ entrypoints==0.4
30
+ executing==1.2.0
31
+ fastjsonschema==2.16.2
32
+ filelock==3.6.0
33
+ flit_core==3.6.0
34
+ fonttools==4.25.0
35
+ future==0.18.2
36
+ gitdb==4.0.7
37
+ GitPython==3.1.18
38
+ huggingface-hub==0.10.1
39
+ idna==3.4
40
+ importlib-metadata==4.11.3
41
+ ipykernel==6.19.3
42
+ ipython==8.7.0
43
+ ipython-genutils==0.2.0
44
+ ipywidgets==7.6.5
45
+ jedi==0.18.2
46
+ Jinja2==3.1.2
47
+ jsonschema==4.16.0
48
+ jupyter_client==7.4.8
49
+ jupyter_core==5.1.0
50
+ jupyter-server==1.18.1
51
+ jupyterlab-pygments==0.1.2
52
+ jupyterlab-widgets==1.0.0
53
+ kiwisolver==1.4.2
54
+ lxml==4.9.1
55
+ MarkupSafe==2.1.1
56
+ matplotlib==3.6.2
57
+ matplotlib-inline==0.1.6
58
+ mistune==0.8.4
59
+ mkl-fft==1.3.1
60
+ mkl-random==1.2.2
61
+ mkl-service==2.4.0
62
+ munkres==1.1.4
63
+ nbclassic==0.4.8
64
+ nbclient==0.5.13
65
+ nbconvert==6.5.4
66
+ nbformat==5.7.0
67
+ nest-asyncio==1.5.6
68
+ notebook==6.5.2
69
+ notebook_shim==0.2.2
70
+ numexpr==2.8.4
71
+ numpy==1.22.3
72
+ packaging==21.3
73
+ pandas==1.5.2
74
+ pandocfilters==1.5.0
75
+ parso==0.8.3
76
+ pickleshare==0.7.5
77
+ Pillow==9.3.0
78
+ pip==22.3.1
79
+ platformdirs==2.6.0
80
+ ply==3.11
81
+ prometheus-client==0.14.1
82
+ prompt-toolkit==3.0.36
83
+ protobuf==3.20.1
84
+ psutil==5.9.0
85
+ pure-eval==0.2.2
86
+ pyarrow==8.0.0
87
+ pycodestyle==2.8.0
88
+ pycparser==2.21
89
+ pydeck==0.7.1
90
+ Pygments==2.13.0
91
+ Pympler==0.9
92
+ pyOpenSSL==22.0.0
93
+ pyparsing==3.0.9
94
+ PyQt5==5.15.7
95
+ PyQt5-sip==12.11.0
96
+ pyrsistent==0.18.0
97
+ PySocks==1.7.1
98
+ python-dateutil==2.8.2
99
+ pytz==2022.1
100
+ PyYAML==6.0
101
+ pyzmq==23.2.0
102
+ regex==2022.7.9
103
+ requests==2.28.1
104
+ rich==12.5.1
105
+ semver==2.13.0
106
+ Send2Trash==1.8.0
107
+ setuptools==65.5.0
108
+ sip==6.6.2
109
+ six==1.16.0
110
+ smmap==4.0.0
111
+ sniffio==1.2.0
112
+ soupsieve==2.3.2.post1
113
+ stack-data==0.6.2
114
+ streamlit==1.11.0
115
+ terminado==0.13.1
116
+ tinycss2==1.2.1
117
+ tokenizers==0.11.4
118
+ toml==0.10.2
119
+ toolz==0.12.0
120
+ torch==1.13.1
121
+ torchaudio==0.13.1
122
+ torchvision==0.14.1
123
+ tornado==6.2
124
+ tqdm==4.64.1
125
+ traitlets==5.8.0
126
+ transformers==4.24.0
127
+ typing_extensions==4.4.0
128
+ tzlocal==2.1
129
+ urllib3==1.26.13
130
+ validators==0.18.2
131
+ watchdog==2.1.6
132
+ wcwidth==0.2.5
133
+ webencodings==0.5.1
134
+ websocket-client==0.58.0
135
+ wheel==0.37.1
136
+ zipp==3.8.0
setup.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
# Upgrade pip inside the app's virtual environment.
/home/appuser/venv/bin/python -m pip install --upgrade pip

# Write the Streamlit server config. A heredoc is used instead of
# `echo "...\n..."` because whether echo expands "\n" depends on the shell;
# $PORT is still expanded since the EOF delimiter is unquoted.
mkdir -p ~/.streamlit/
cat > ~/.streamlit/config.toml <<EOF
[server]
headless = true
port = $PORT
enableCORS = false

EOF