|
import logging |
|
|
|
|
|
def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
                   exclude_list=None, topk=10):
    """Retrieve the top-k matching objects per query vector, grouped by image.

    For every query vector in ``xq`` a sub-query asks the object table for
    the ``topk`` best matches according to ``distance`` (ordered by ``dist``
    descending). Any ``obj_id`` in ``exclude_list`` is dropped. The remaining
    objects are joined with their image rows and scored with
    ``sigmoid(dot(prelogit, [*vector, 1]))``. The per-vector results are
    combined with ``UNION ALL`` and aggregated per image.

    Args:
        client: Database client exposing ``query(sql).named_results()``.
        xq: Iterable of query vectors. Each must provide ``tolist()``
            (e.g. a NumPy array row).
        IMG_DB_NAME: Name of the image table, interpolated into the SQL.
        OBJ_DB_NAME: Name of the object table, interpolated into the SQL.
        exclude_list: Object ids to leave out of the results. ``None`` (the
            default) excludes nothing.
        topk: Number of matches to retrieve, and keep, per query vector.

    Returns:
        A list of dicts, one per image, ordered by ``img_score`` descending.
        Besides ``img_id``, ``img_url``, ``img_w`` and ``img_h``, each dict
        holds the grouped arrays ``box_id``, ``cx``, ``cy``, ``w``, ``h``,
        ``logit`` and ``label``. An empty ``xq`` returns ``[]`` without
        querying the database.
    """

    def _quote(value):
        # Escape backslashes and single quotes so an id cannot break out of
        # the SQL string literal.
        text = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{text}'"

    exclude_list = [] if exclude_list is None else list(exclude_list)

    # Each vector becomes an SQL array literal with a trailing 1 appended,
    # i.e. a bias term for the dot product with ``prelogit``.
    vectors = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq
    ]
    if not vectors:
        # UNION ALL over zero sub-queries would produce ``FROM ()``.
        return []

    _cond = (
        f"WHERE obj_id NOT IN ({','.join(_quote(i) for i in exclude_list)})"
        if exclude_list else ""
    )

    subqueries = []
    score_terms = []
    for idx, vec in enumerate(vectors):
        # Highest pred_logit among this image's boxes that carry label ``idx``.
        score_terms.append(
            f"arrayReduce('maxIf', logit, arrayMap(x->x={idx}, label))")
        # NOTE(review): DESC ordering assumes a similarity-style metric
        # (e.g. inner product) on the vector index -- confirm.
        subqueries.append(f"""
        SELECT img_id, img_url, img_w, img_h,
               1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {vec})))) AS pred_logit,
               obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {idx} AS l
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME}
        ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        PREWHERE obj_id IN (
            SELECT obj_id FROM (
                SELECT obj_id, distance('topK={topk}', 'nprobe=32')(prelogit, {vec}) AS dist FROM {OBJ_DB_NAME}
                ORDER BY dist DESC
            ) {_cond} LIMIT {topk}
        )
        """)

    union_sql = ' UNION ALL '.join(subqueries)
    # The image score is the sum, over query labels, of the per-label maxima.
    # NaN terms are dropped before summing.
    img_score_sql = (
        "arraySum(arrayFilter(x->NOT isNaN(x), "
        f"array({','.join(score_terms)}))) AS img_score"
    )

    q_str = f"""
    SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
           groupArray(box_cx) AS cx, groupArray(box_cy) AS cy,
           groupArray(box_w) AS w, groupArray(box_h) AS h,
           groupArray(pred_logit) AS logit, groupArray(l) as label,
           {img_score_sql}
    FROM
    ({union_sql})
    GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
    """

    return [dict(row) for row in client.query(q_str).named_results()]
|
|
|
|
|
def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
    """Score every object of the given images against each query vector.

    For every query vector in ``xq``, all objects whose image is in
    ``img_ids`` are joined with their image rows and scored with
    ``sigmoid(dot(prelogit, [*vector, 1]))``. Objects scoring at or below
    ``thresh`` are dropped. The per-vector results are combined with
    ``UNION ALL`` and aggregated per image.

    Args:
        client: Database client exposing ``query(sql).named_results()``.
        xq: Iterable of query vectors. Each must provide ``tolist()``.
        img_ids: Image ids to restrict the search to.
        IMG_DB_NAME: Name of the image table, interpolated into the SQL.
        OBJ_DB_NAME: Name of the object table, interpolated into the SQL.
        thresh: Minimum ``pred_logit`` (exclusive) for an object to be kept.
            A value ``<= 0`` disables the filter.

    Returns:
        A list of dicts, one per image, ordered by ``img_score`` descending,
        with grouped arrays ``box_id``, ``cx``, ``cy``, ``w``, ``h``,
        ``logit`` and ``label``. Empty ``xq`` or ``img_ids`` returns ``[]``
        without querying the database.
    """

    def _quote(value):
        # Escape backslashes and single quotes so an id cannot break out of
        # the SQL string literal.
        text = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{text}'"

    # Materialise once so generators work and emptiness can be checked.
    ids = list(img_ids)

    # Each vector becomes an SQL array literal with a trailing 1 appended,
    # i.e. a bias term for the dot product with ``prelogit``.
    vectors = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq
    ]
    if not vectors or not ids:
        # Either would produce invalid SQL (``FROM ()`` / ``IN ()``), and
        # there is nothing to match anyway.
        return []

    image_list = ','.join(_quote(i) for i in ids)
    _thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""

    subqueries = []
    score_terms = []
    for idx, vec in enumerate(vectors):
        # Highest pred_logit among this image's boxes that carry label ``idx``.
        score_terms.append(
            f"arrayReduce('maxIf', logit, arrayMap(x->x={idx}, label))")
        subqueries.append(f"""
        SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h,
               (1 / (1 + exp(-(arraySum(arrayMap((x,y)->x*y, prelogit, {vec})))))) AS pred_logit,
               obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {idx} AS l
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME}
        ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        PREWHERE img_id IN ({image_list})
        {_thresh}
        """)

    union_sql = ' UNION ALL '.join(subqueries)
    # The image score is the sum, over query labels, of the per-label maxima.
    # NaN terms are dropped before summing.
    img_score_sql = (
        "arraySum(arrayFilter(x->NOT isNaN(x), "
        f"array({','.join(score_terms)}))) AS img_score"
    )

    q_str = f"""
    SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
           groupArray(box_cx) AS cx, groupArray(box_cy) AS cy,
           groupArray(box_w) AS w, groupArray(box_h) AS h,
           groupArray(pred_logit) AS logit, groupArray(l) as label,
           {img_score_sql}
    FROM
    ({union_sql})
    GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
    """

    return [dict(row) for row in client.query(q_str).named_results()]
|
|
|
|
|
def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
    """Run a plain vector search per query vector and group the hits by label.

    For every query vector in ``xq``, objects joined with their image rows
    are scored with ``distance('topK=topk', 'nprobe=32')``. At most ``topk``
    rows are kept per vector. The per-vector results are combined with
    ``UNION ALL`` and grouped by query index (``label``).

    Args:
        client: Database client exposing ``query(sql).named_results()``.
        xq: Iterable of query vectors. Each must provide ``tolist()``.
        IMG_DB_NAME: Name of the image table, interpolated into the SQL.
        OBJ_DB_NAME: Name of the object table, interpolated into the SQL.
        thresh: Threshold on ``pred_logit`` (exclusive). A value ``<= 0``
            disables the filter.
        topk: Number of results requested from, and kept per, vector search.

    Returns:
        A list of dicts, one per query label, each holding ``label`` and the
        grouped arrays ``img_url``, ``img_w``, ``img_h``, ``cx``, ``cy``,
        ``w``, ``h``, ``d`` (raw distances) and ``logit``
        (``1 / (1 + exp(-dist))``). An empty ``xq`` returns ``[]`` without
        querying the database.
    """
    # Each vector becomes an SQL array literal with a trailing 1 appended,
    # i.e. a bias term matching the other query builders.
    vectors = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq
    ]
    if not vectors:
        # UNION ALL over zero sub-queries would produce ``FROM ()``.
        return []

    # NOTE(review): unlike the other query builders, the SELECT below does not
    # define a ``pred_logit`` alias, so this filter only works if OBJ_DB_NAME
    # has a ``pred_logit`` column -- confirm against the schema.
    _thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""

    subqueries = [
        f"""
        SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
               obj_id, box_cx, box_cy, box_w, box_h, {idx} AS l, distance('topK={topk}', 'nprobe=32')(prelogit, {vec}) AS dist
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME}
        ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        {_thresh} LIMIT {topk}
        """
        for idx, vec in enumerate(vectors)
    ]
    union_sql = " UNION ALL ".join(subqueries)

    q_str = f"""
    SELECT groupArray(img_url) AS img_url, groupArray(img_w) AS img_w, groupArray(img_h) AS img_h,
           groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
           l AS label, groupArray(dist) as d,
           groupArray(1 / (1 + exp(-dist))) AS logit FROM (
        {union_sql}
    )
    GROUP BY l
    """

    return [dict(row) for row in client.query(q_str).named_results()]
|
|