|
from time import time |
|
from io import BytesIO |
|
import torch |
|
import streamlit as st |
|
import streamlit.components.v1 as components |
|
import numpy as np |
|
import torch |
|
import logging |
|
from os import environ |
|
from transformers import OwlViTProcessor, OwlViTForObjectDetection |
|
from bot import Bot, Message |
|
from parse import parse |
|
from clickhouse_connect import get_client |
|
from classifier import Classifier, prompt2vec, tune, SplitLayer |
|
from query_model import simple_query, topk_obj_query, rev_query |
|
from card_model import card, obj_card, style |
|
from box_utils import postprocess |
|
|
|
# Environment flag set before any model is loaded.
# NOTE(review): presumably consumed by the HuggingFace `tokenizers` library to
# allow its internal parallelism -- confirm against that library's docs.
environ["TOKENIZERS_PARALLELISM"] = "true"

# Fully-qualified table names handed to the query helpers
# (topk_obj_query / rev_query / simple_query) and to Classifier.
OBJ_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_objects"
IMG_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_images"
# HuggingFace model identifier passed to build_model() by init_owlvit().
MODEL_ID = "google/owlvit-base-patch32"
# Length of a query vector: used for the random query and to validate the
# shape of weights loaded from an uploaded ONNX file.
DIMS = 512

# Duration of the most recent query in milliseconds. The page script
# overwrites it after calling query() and only shows the timing banner when
# the value is greater than zero.
qtime = 0
|
|
|
|
|
def build_model(name="google/owlvit-base-patch32"):
    """Load an OwlViT detection model together with its processor.

    Args:
        name (str, optional): Name of the HuggingFace OwlViT checkpoint.
            Defaults to "google/owlvit-base-patch32".

    Returns:
        (model, processor): the OwlViT model, moved to "cuda" when CUDA is
            available and to "cpu" otherwise, and the processor loaded from
            the same checkpoint name.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    detector = OwlViTForObjectDetection.from_pretrained(name)
    detector = detector.to(target_device)
    preprocessor = OwlViTProcessor.from_pretrained(name)
    return detector, preprocessor
|
|
|
|
|
@st.experimental_singleton(show_spinner=False)
def init_owlvit():
    """Build the OwlViT model for MODEL_ID once and cache the result.

    Returns:
        model, processor: the pair produced by build_model(MODEL_ID)
    """
    return build_model(MODEL_ID)
|
|
|
|
|
@st.experimental_singleton(show_spinner=False)
def init_db():
    """Initialize the database connection from the app secrets.

    Reads ``DB_URL`` (expected form ``<scheme>://<host>:<port>``), ``USER``
    and ``PASSWD`` from ``st.secrets`` and opens a client with them.

    Returns:
        meta: An empty list. Callers store it in ``st.session_state.meta``,
            where it accumulates the entries passed to ``query`` as
            ``exclude_list``.
        client: Database connection object returned by ``get_client``.

    Raises:
        ValueError: If ``DB_URL`` does not match ``<scheme>://<host>:<port>``
            or its port is not an integer.
    """
    parsed = parse("{http_pre}://{host}:{port}", st.secrets["DB_URL"])
    if parsed is None:
        # parse() returns None on a mismatch; fail with a clear message
        # instead of a TypeError on subscripting. The URL itself is a secret,
        # so it is deliberately not echoed in the message.
        raise ValueError(
            "DB_URL secret must have the form '<scheme>://<host>:<port>'"
        )

    # Honour the scheme written in the URL when it is one the client knows;
    # any other scheme keeps the client's own default, as before.
    extra = {}
    scheme = parsed["http_pre"].lower()
    if scheme in ("http", "https"):
        extra["interface"] = scheme

    client = get_client(
        host=parsed["host"],
        # parse() yields the port as text; the client expects an integer.
        port=int(parsed["port"]),
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
        **extra,
    )
    meta = []
    return meta, client
|
|
|
|
|
def refresh_index():
    """Reset the session state so that a new search can be started.

    Empties the viewed-list, resets the query counter, drops the cached
    database connection and opens a new one, and removes the classifier,
    the query vector and the remembered top-k image ids when present.
    """
    del st.session_state["meta"]
    st.session_state.meta = []
    st.session_state.query_num = 0
    logging.info(f"Refresh for '{st.session_state.meta}'")

    # Clear the cached init_db() result so the call below reconnects.
    init_db.clear()
    fresh_meta, fresh_client = init_db()
    st.session_state.meta = fresh_meta
    st.session_state.index = fresh_client

    # Per-search state that must not leak into the next search.
    for stale_key in ("clf", "xq", "topk_img_id"):
        if stale_key in st.session_state:
            del st.session_state[stale_key]
|
|
|
|
|
def query(xq, exclude_list=None):
    """Query matched w.r.t a given vector

    In this part, we will retrieve A LOT OF data from the server,
    including TopK boxes and their embeddings, the counterpart of non-TopK
    boxes in TopK images.

    The whole sequence of four database calls is attempted up to three
    times. After a failed attempt the connection is re-requested from
    init_db() and every partial result is discarded. If all attempts fail,
    four empty lists are returned.

    Args:
        xq (numpy.ndarray or list of floats): Query vector. It is
            L2-normalised along its last axis before being sent.
        exclude_list (optional): Forwarded to ``topk_obj_query`` as
            ``exclude_list`` (the user's view list). Defaults to None.

    Returns:
        matches: list of Records object. Keys referring to selected columns
            group by images. Exclude the user's viewlist.
        img_matches: list of Records object. Containing other non-TopK but
            hit objects among TopK images.
        side_matches: list of Records object. Containing REAL TopK objects
            disregard the user's view history
        o_matches: result of ``rev_query`` over the image ids stored in
            ``st.session_state.topk_img_id`` (the ids of the first
            successful top-k query of this session state).
    """
    attempt = 0
    xq = xq / np.linalg.norm(xq, axis=-1, ord=2, keepdims=True)
    status_bar = [st.empty(), st.empty()]
    status_bar[0].write("Retrieving Another TopK Images...")
    pbar = status_bar[1].progress(0)

    # Bind every result up front so the return statement is well defined
    # even when all attempts raise.
    matches, img_matches, side_matches, o_matches = [], [], [], []

    while attempt < 3:
        try:
            matches = topk_obj_query(
                st.session_state.index,
                xq,
                IMG_DB_NAME,
                OBJ_DB_NAME,
                exclude_list=exclude_list,
                topk=5000,
            )
            img_ids = [r["img_id"] for r in matches]
            # Remember the ids of the very first result set only.
            if "topk_img_id" not in st.session_state:
                st.session_state.topk_img_id = img_ids
            status_bar[0].write("Retrieving TopK Images...")
            pbar.progress(25)
            o_matches = rev_query(
                st.session_state.index,
                xq,
                st.session_state.topk_img_id,
                IMG_DB_NAME,
                OBJ_DB_NAME,
                thresh=0.1,
            )
            status_bar[0].write("Retrieving TopKs Objects...")
            pbar.progress(50)
            side_matches = simple_query(
                st.session_state.index,
                xq,
                IMG_DB_NAME,
                OBJ_DB_NAME,
                thresh=-1,
                topk=5000,
            )
            status_bar[0].write("Retrieving Non-TopK in Another TopK Images...")
            pbar.progress(75)
            if len(img_ids) > 0:
                img_matches = rev_query(
                    st.session_state.index,
                    xq,
                    img_ids,
                    IMG_DB_NAME,
                    OBJ_DB_NAME,
                    thresh=0.1,
                )
            else:
                img_matches = []
            status_bar[0].write("DONE!")
            pbar.progress(100)
            break
        except Exception as e:
            # Deliberately broad: any failure triggers a retry.
            logging.warning(str(e))
            # NOTE(review): init_db is cached via st.experimental_singleton
            # and is not cleared here, so this presumably hands back the same
            # client -- confirm whether init_db.clear() is wanted.
            st.session_state.meta, st.session_state.index = init_db()
            attempt += 1
            # Discard partial results so a half-finished attempt never leaks
            # into the return value.
            matches, img_matches, side_matches, o_matches = [], [], [], []

    # Remove the status text and the progress bar from the page.
    for slot in status_bar:
        slot.empty()
    if len(matches) == 0:
        logging.error(f"No matches found for '{OBJ_DB_NAME}'")
    return matches, img_matches, side_matches, o_matches
|
|
|
|
|
@st.experimental_singleton(show_spinner=False)
def init_random_query():
    """Draw one random query vector and cache it.

    Returns:
        xq: array of shape (1, DIMS) sampled with np.random.rand and scaled
            to unit norm along its last axis
    """
    sample = np.random.rand(1, DIMS)
    length = np.linalg.norm(sample, keepdims=True, axis=-1)
    return sample / length
|
|
|
|
|
def submit(meta):
    """Tune the classifier with the labels chosen by the user and re-query.

    Args:
        meta: entries appended to ``st.session_state.meta``, which is then
            passed to ``query`` as its exclude list.
    """
    state = st.session_state
    state.meta.extend(meta)
    state.step += 1

    # One (first box field, class index) pair per displayed box; the class
    # index is the position of the label picked in the "label-<id>" widget.
    labelled = [
        (box[0], state.text_prompts.index(state[f"label-{box_id}"]))
        for box_id, box in state.matched_boxes.items()
    ]
    X, y = zip(*labelled)

    state.xq = tune(state.clf, X, y, iters=int(state.iters))

    results = query(state.xq, state.meta)
    (
        state.matches,
        state.img_matches,
        state.side_matches,
        state.o_matches,
    ) = results
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Page script: everything below runs top to bottom on each execution of the
# file. Widget creation order is significant, so statements are kept in
# their original order.
# ---------------------------------------------------------------------------

# Write whatever markup style() returns, with HTML allowed.
# NOTE(review): presumably the CSS used by the cards -- confirm in card_model.
st.write(style(), unsafe_allow_html=True)

# Incident reporter; exceptions raised by the page body are forwarded to it
# in the except clause at the very bottom.
bot = Bot(app_name="HF OwlViT", enabled=True, bot_key=st.secrets['BOT_KEY'])
try:
    with st.spinner("Connecting DB..."):
        st.session_state.meta, st.session_state.index = init_db()

    with st.spinner("Loading Models..."):
        # The processor returned by init_owlvit() is bound to `tokenizer`
        # here and handed to prompt2vec below.
        model, tokenizer = init_owlvit()

    # ----- Start page: shown while no query vector is stored -----
    if "xq" not in st.session_state:
        with st.container():
            st.title("Object Detection Safari")
            # Eight placeholders created in list order; slots 0-5 are filled
            # below (slot 3 after slots 4 and 5), and all eight are emptied
            # once a search starts.
            start = [st.empty() for _ in range(8)]
            start[0].info(
                """
                We extracted boxes from **287,104** images in COCO Dataset, including its train / val / test /
                unlabeled images, collecting **165,371,904 boxes** which are then filtered with common prompts.
                You can search with almost any words or phrases you can think of. Please enjoy your journey of
                an adventure to COCO.
                """
            )
            prompt = start[1].text_input(
                "Prompt:",
                value="",
                placeholder="Examples: football, billboard, stop sign, watermark ...",
            )
            with start[2].container():
                st.write(
                    "You can search with multiple keywords. Plese separate with commas but with no space."
                )
                st.write("For example: `cat,dog,tree`")
                st.markdown(
                    """
                    <p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>
                    """,
                    unsafe_allow_html=True,
                )

            # Optional upload of a previously downloaded classifier (ONNX).
            upld_model = start[4].file_uploader(
                "Or you can upload your previous run!", type="onnx"
            )
            # Only clickable once a file has been uploaded.
            upld_btn = start[5].button(
                "Use Loaded Weights", disabled=upld_model is None, on_click=refresh_index
            )

            with start[3]:
                col = st.columns(8)
                # True only when there is neither a typed prompt nor an upload.
                has_no_prompt = len(prompt) == 0 and upld_model is None
                # "Prompt" needs a non-empty prompt; "Random" is enabled only
                # when nothing else was provided. Both run refresh_index as
                # their click callback.
                prompt_xq = col[6].button(
                    "Prompt", disabled=len(prompt) == 0, on_click=refresh_index
                )
                random_xq = col[7].button(
                    "Random", disabled=not has_no_prompt, on_click=refresh_index
                )
            matches = []
            img_matches = []
            if random_xq:
                # Random search: cached random unit vector, labelled "unknown".
                xq = init_random_query()
                st.session_state.xq = xq
                prompt = "unknown"
                # The last entry "none" is an extra label after the prompts.
                st.session_state.text_prompts = prompt.split(",") + ["none"]
                # Clear the start page widgets.
                _ = [elem.empty() for elem in start]
                t0 = time()
                (
                    st.session_state.matches,
                    st.session_state.img_matches,
                    st.session_state.side_matches,
                    st.session_state.o_matches,
                ) = query(st.session_state.xq, st.session_state.meta)
                t1 = time()
                # Elapsed query time in milliseconds.
                qtime = (t1 - t0) * 1000
            elif prompt_xq or upld_btn:
                if upld_model is not None:
                    # Resume from an uploaded classifier; an upload takes
                    # precedence over a typed prompt.
                    import onnx
                    from onnx import numpy_helper

                    _model = onnx.load(upld_model)
                    # Prompt names are recovered from the graph output names.
                    st.session_state.text_prompts = [
                        node.name for node in _model.graph.output
                    ] + ["none"]
                    weights = _model.graph.initializer
                    # First initializer, transposed, becomes the query matrix.
                    xq = numpy_helper.to_array(weights[0]).T
                    # Expect one row per prompt (excluding "none"), DIMS wide.
                    assert (
                        xq.shape[0] == len(st.session_state.text_prompts) - 1
                        and xq.shape[1] == DIMS
                    )
                    st.session_state.xq = xq
                    _ = [elem.empty() for elem in start]
                else:
                    # Text search: comma-separated prompts embedded by
                    # prompt2vec.
                    logging.info(f"Input prompt is {prompt}")
                    st.session_state.text_prompts = prompt.split(",") + ["none"]
                    input_ids, xq = prompt2vec(
                        st.session_state.text_prompts[:-1], model, tokenizer
                    )
                    st.session_state.xq = xq
                    _ = [elem.empty() for elem in start]
                t0 = time()
                (
                    st.session_state.matches,
                    st.session_state.img_matches,
                    st.session_state.side_matches,
                    st.session_state.o_matches,
                ) = query(st.session_state.xq, st.session_state.meta)
                t1 = time()
                # Elapsed query time in milliseconds.
                qtime = (t1 - t0) * 1000

    # ----- Result page: shown once a query vector is stored -----
    if "xq" in st.session_state:
        o_matches = st.session_state.o_matches
        side_matches = st.session_state.side_matches
        img_matches = st.session_state.img_matches
        matches = st.session_state.matches

        # Create the classifier (and reset the training step counter) the
        # first time results are shown for this query.
        if "clf" not in st.session_state:
            st.session_state.clf = Classifier(st.session_state.index, OBJ_DB_NAME, st.session_state.xq)
            st.session_state.step = 0
        # qtime is only positive on the run in which a query was executed.
        if qtime > 0:
            st.info(
                "Query done in {0:.2f} ms and returned {1:d} images with {2:d} boxes".format(
                    qtime,
                    len(matches),
                    sum(
                        [
                            len(m["box_id"]) + len(im["box_id"])
                            for m, im in zip(matches, img_matches)
                        ]
                    ),
                )
            )
        # Bias-free linear layer carrying the current classifier weights,
        # one output per row of xq.
        lnprob = torch.nn.Linear(st.session_state.xq.shape[1], st.session_state.xq.shape[0], bias=False)
        lnprob.weight = torch.nn.Parameter(st.session_state.clf.weight)

        # Export the layer followed by SplitLayer to an in-memory ONNX file;
        # output names are the prompts, which the upload branch above reads
        # back from the graph outputs.
        st.session_state.dnld_model = BytesIO()
        torch.onnx.export(
            torch.nn.Sequential(lnprob, SplitLayer()),
            torch.zeros([1, len(st.session_state.xq[0])]),
            st.session_state.dnld_model,
            input_names=["input"],
            output_names=st.session_state.text_prompts[:-1],
        )

        # Default file name: prompts joined by "_" with spaces turned into "-".
        dnld_nam = st.text_input(
            "Download Name:",
            f'{("_".join([i.replace(" ", "-") for i in st.session_state.text_prompts[:-1]]) if "text_prompts" in st.session_state else "model")}.onnx',
            max_chars=50,
        )
        dnld_btn = st.download_button(
            "Download your classifier!", st.session_state.dnld_model, dnld_nam
        )

        # Edge length used for the sidebar object cards: shrinks with the
        # number of labels and is capped at 120.
        side_bar_len = min(240 // len(st.session_state.text_prompts), 120)
        with st.sidebar:
            with st.expander("Top-K Images"):
                with st.container():
                    # Both ratios move towards 1 as the step counter grows.
                    boxes_w_img, _ = postprocess(
                        o_matches, st.session_state.text_prompts, o_matches,
                        agnostic_ratio=1-0.6**(st.session_state.step+1),
                        class_ratio=1-0.2**(st.session_state.step+1)
                    )
                    # Highest image score (tuple position 4) first.
                    boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)
                    for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
                        args = img_url, img_w, img_h, boxes
                        st.write(card(*args), unsafe_allow_html=True)

            with st.expander("Top-K Objects", expanded=True):
                # One column per prompt (the trailing "none" is excluded);
                # zip pairs each column with one side_matches record.
                side_cols = st.columns(len(st.session_state.text_prompts[:-1]))
                for _cols, m in zip(side_cols, side_matches):
                    with _cols.container():
                        for cx, cy, w, h, logit, img_url, img_w, img_h in zip(
                            m["cx"],
                            m["cy"],
                            m["w"],
                            m["h"],
                            m["logit"],
                            m["img_url"],
                            m["img_w"],
                            m["img_h"],
                        ):
                            st.write(
                                "{:s}: {:.4f}".format(
                                    st.session_state.text_prompts[m["label"]], logit
                                )
                            )
                            _html = obj_card(
                                img_url, img_w, img_h, cx, cy, w, h, dst_len=side_bar_len
                            )
                            components.html(_html, side_bar_len, side_bar_len)
        with st.container():
            # Labelling form: the per-box selectboxes and both submit buttons
            # live in this one form.
            with st.form("batch", clear_on_submit=False):
                col = st.columns([1, 9])

                # Nothing to show: warn. Otherwise offer the iteration count
                # that submit() passes on to tune().
                if len(matches) <= 0:
                    st.warning(
                        "Oops! We didn't find anything relevant to your query! Pleas try another one :/"
                    )
                else:
                    st.session_state.iters = st.slider(
                        "Number of Iterations to Update",
                        min_value=0,
                        max_value=10,
                        step=1,
                        value=2,
                    )
                # NOTE(review): indentation was lost in this file; this button
                # is assumed to sit outside the if/else so that a way back
                # exists in both cases -- confirm.
                col[1].form_submit_button("Choose a new prompt", on_click=refresh_index)

                # There are results to label.
                if len(matches) > 0:
                    with st.container():
                        prompt_labels = st.session_state.text_prompts

                        # Second return value is handed to submit() by the
                        # "Train!" button below.
                        boxes_w_img, meta = postprocess(
                            matches, st.session_state.text_prompts, img_matches,
                            agnostic_ratio=1-0.6**(st.session_state.step+1),
                            class_ratio=1-0.2**(st.session_state.step+1)
                        )

                        # Highest image score (tuple position 4) first.
                        boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)

                        # Rebuilt on every run: box id -> box tuple, read by
                        # submit() to collect the training pairs.
                        st.session_state.matched_boxes = {}

                        for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
                            # Register this image's boxes for training.
                            st.session_state.matched_boxes.update({b[0]: b for b in boxes})
                            args = img_url, img_w, img_h, boxes

                            # One expander per image, titled "<id>: <score>".
                            with st.expander(
                                "{:s}: {:.4f}".format(img_id, img_score), expanded=True
                            ):
                                ind_b = 0
                                # Four columns: the image card, then three
                                # columns that the boxes are dealt into in
                                # round-robin order.
                                img_row = st.columns([4, 2, 2, 2])
                                img_row[0].write(card(*args), unsafe_allow_html=True)

                                for b in boxes:
                                    _id, cx, cy, w, h, label, logit, is_selected = b[:8]
                                    with img_row[1 + ind_b % 3].container():
                                        st.write("{:s}: {:.4f}".format(label, logit))

                                        _html = obj_card(
                                            img_url, img_w, img_h, *b[1:5], dst_len=120
                                        )
                                        components.html(_html, 120, 120)

                                        # Preselected with the box's current
                                        # label; the key "label-<id>" is what
                                        # submit() looks up.
                                        st.selectbox(
                                            "Class",
                                            prompt_labels,
                                            index=prompt_labels.index(label),
                                            key=f"label-{_id}",
                                        )
                                    ind_b += 1
                        col[0].form_submit_button("Train!", on_click=lambda: submit(meta))
except Exception as e:
    # Any failure in the page body is reported through the bot.
    msg = Message()
    # with_traceback(None) clears the traceback on the exception before it
    # is turned into the message text.
    msg.content = str(e.with_traceback(None))
    msg.type_hint = str(type(e).__name__)
    bot.incident(msg)
|
|