File size: 4,955 Bytes
3f1124e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import logging


def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
                   exclude_list=None, topk=10):
    """Find images whose objects best match one or more query vectors.

    For each query vector, a vector search (``distance`` with ``topK=topk``)
    over ``OBJ_DB_NAME.prelogit`` selects candidate ``obj_id`` values. Ids
    listed in ``exclude_list`` are dropped, and at most ``topk`` candidates
    are kept. The selected objects are joined with ``IMG_DB_NAME`` on
    ``img_id`` and scored as ``1 / (1 + exp(-sum(prelogit * query)))``. The
    rows from all query vectors are then grouped into one row per image.

    Args:
        client: Object exposing ``fetch(sql)``. The generated SQL text is
            passed to it and its result is returned unchanged.
        xq: Iterable of query vectors, each exposing ``tolist()`` (e.g. a
            1-D NumPy array). A constant ``1`` is appended to every vector
            before it is rendered into SQL. This is presumably a bias term
            matching the layout of ``prelogit`` -- TODO confirm.
        IMG_DB_NAME: Name of the image table, joined on ``img_id``.
        OBJ_DB_NAME: Name of the object table holding ``obj_id`` and
            ``prelogit``.
        exclude_list: Object ids to omit from the results. ``None`` (the
            default) omits nothing. Ids are wrapped in single quotes but are
            NOT escaped, so only pass trusted ids.
        topk: Number of neighbours requested from the vector search, and
            the maximum number of objects kept per query vector.

    Returns:
        Whatever ``client.fetch`` returns for a query with one row per
        image. Each row carries grouped per-object arrays (``box_id``,
        ``cx``, ``cy``, ``w``, ``h``, ``logit``, ``label``, ``cls_emb``) and
        an ``img_score``. ``label`` holds the index of the query vector that
        produced each object. ``img_score`` is the sum, over query indices,
        of the image's best ``logit`` for that query, with NaN entries
        dropped. Rows are ordered by ``img_score`` descending.
    """
    if exclude_list is None:
        exclude_list = []

    # Render each query vector as a SQL array literal, with a trailing 1 appended.
    xq_s = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq]

    exclude_list_str = ','.join(f"'{i}'" for i in exclude_list)
    # The exclusion filter runs on the vector-search results, so excluded
    # ids can leave fewer than ``topk`` objects for a query.
    _cond = (f"WHERE obj_id NOT IN ({exclude_list_str})"
             if len(exclude_list) > 0 else "")

    subqueries = []
    score_terms = []
    for _l, _xq in enumerate(xq_s):
        # Best logit among this image's objects that matched query ``_l``.
        score_terms.append(
            f"arrayReduce('maxIf', logit, arrayMap(x->x={_l}, label))")
        subqueries.append(f"""
            SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h,
                1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {_xq})))) AS pred_logit,
                obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {_l} AS l
            FROM {OBJ_DB_NAME}
            JOIN {IMG_DB_NAME}
            ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
            PREWHERE obj_id IN (
                SELECT obj_id FROM (
                    SELECT obj_id, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist FROM {OBJ_DB_NAME}
                    ORDER BY dist DESC
                ) {_cond} LIMIT {topk}
            )
            """)

    union_sql = ' UNION ALL '.join(subqueries)
    # NaN per-query maxima are filtered out before summing into img_score.
    img_score_expr = (
        "arraySum(arrayFilter(x->NOT isNaN(x), "
        f"array({','.join(score_terms)}))) AS img_score")
    q_str = f"""
            SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
                groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
                groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
                {img_score_expr}
            FROM
                    ({union_sql})
            GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
    """
    return client.fetch(q_str)


def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
    """Score the objects of a given set of images against query vectors.

    No nearest-neighbour search is involved. Every object row whose
    ``img_id`` is in ``img_ids`` is scored against each query vector as
    ``pred_logit = 1 / (1 + exp(-sum(prelogit * query)))``. The per-query
    SELECTs are combined with ``UNION ALL`` and grouped per image.

    Args:
        client: Object exposing ``fetch(sql)``; its result is returned as-is.
        xq: Iterable of query vectors exposing ``tolist()``. A constant ``1``
            is appended to each before rendering (presumably a bias term --
            TODO confirm).
        img_ids: Image ids to restrict the scan to. Each is rendered as a
            single-quoted SQL literal without escaping, so trusted input only.
        IMG_DB_NAME: Name of the image table, joined on ``img_id``.
        OBJ_DB_NAME: Name of the object table (left side of the join).
        thresh: If ``thresh > 0``, only rows with ``pred_logit > thresh`` are
            kept; otherwise no score filter is applied.

    Returns:
        The result of ``client.fetch``: one row per image with grouped
        per-object arrays, a ``label`` array holding the query index of each
        object, and an ``img_score`` column (the sum of the image's best,
        non-NaN logit per query). Rows are ordered by ``img_score``
        descending.
    """
    vector_literals = []
    for vec in xq:
        components = vec.tolist() + [1]
        vector_literals.append(
            "[" + ", ".join(str(float(c)) for c in components) + "]")

    id_list = ",".join(f"'{img}'" for img in img_ids)

    if thresh > 0:
        where_clause = f"WHERE pred_logit > {thresh}"
    else:
        where_clause = ""

    def per_query_select(label_idx, vec_lit):
        # Rows of this SELECT are tagged with the index of their query vector.
        return f"""
            SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, 
                (1 / (1 + exp(-(arraySum(arrayMap((x,y)->x*y, prelogit, {vec_lit})))))) AS pred_logit,
                obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {label_idx} AS l
            FROM {OBJ_DB_NAME} 
                JOIN {IMG_DB_NAME} 
                ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
                PREWHERE img_id IN ({id_list})
                {where_clause}
            """

    union_sql = ' UNION ALL '.join(
        per_query_select(idx, lit) for idx, lit in enumerate(vector_literals))

    # For each query index: the best pred_logit among the image's rows.
    per_label_max = ','.join(
        f"arrayReduce('maxIf', logit, arrayMap(x->x={idx}, label))"
        for idx in range(len(vector_literals)))
    score_expr = f"arraySum(arrayFilter(x->NOT isNaN(x), array({per_label_max}))) AS img_score"

    sql = f"""
            SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
                groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
                groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
                {score_expr}
            FROM 
                ({union_sql})
            GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
    """
    return client.fetch(sql)


def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
    """Run a per-query vector search and group the hits by query index.

    For each query vector, a vector search (``distance`` with ``topK=topk``)
    over ``prelogit`` is joined with ``IMG_DB_NAME`` on ``img_id``. A wrapper
    SELECT then adds ``pred_logit = 1 / (1 + exp(-dist))``. When
    ``thresh > 0`` the rows are filtered to ``pred_logit > thresh``, and at
    most ``topk`` rows are kept.

    Args:
        client: Object exposing ``fetch(sql)``; its result is returned as-is.
        xq: Iterable of query vectors exposing ``tolist()``. A constant ``1``
            is appended to each before rendering (presumably a bias term --
            TODO confirm).
        IMG_DB_NAME: Name of the image table, joined on ``img_id``.
        OBJ_DB_NAME: Name of the object table (left side of the join).
        thresh: If ``thresh > 0``, keep only rows with
            ``pred_logit > thresh``; otherwise no score filter is applied.
        topk: Neighbours requested from the vector search, and the maximum
            number of rows kept per query vector.

    Returns:
        The result of ``client.fetch``: one row per query index (``label``)
        with grouped arrays ``img_url``, ``img_w``, ``img_h``, ``cx``,
        ``cy``, ``w``, ``h``, the raw distances ``d`` and ``logit``
        (= ``pred_logit``).
    """
    # Render each query vector as a SQL array literal, with a trailing 1 appended.
    xq_s = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq]
    # ``pred_logit`` is defined in a wrapper around the vector search, so
    # the threshold is applied to the search results rather than inside it.
    _thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""
    subqueries = []
    for _l, _xq in enumerate(xq_s):
        subqueries.append(f"""
            SELECT *, 1 / (1 + exp(-dist)) AS pred_logit FROM (
                SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
                    obj_id, box_cx, box_cy, box_w, box_h, {_l} AS l,
                    distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist
                FROM {OBJ_DB_NAME}
                JOIN {IMG_DB_NAME}
                ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
            ) {_thresh} LIMIT {topk}
            """)
    union_sql = " UNION ALL ".join(subqueries)
    q_str = f"""
                SELECT groupArray(img_url) AS img_url, groupArray(img_w) AS img_w, groupArray(img_h) AS img_h,
                groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
                l AS label, groupArray(dist) as d,
                groupArray(pred_logit) AS logit FROM (
                    {union_sql}
                )
                GROUP BY l
                """
    return client.fetch(q_str)