Spaces:
Runtime error
Runtime error
relation-query
#3
by
mpsk
- opened
- .gitignore +0 -2
- app.py +7 -8
- classifier.py +6 -9
- query_model.py +10 -10
- requirements.txt +1 -3
.gitignore
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
.streamlit/
|
2 |
-
__pycache__
|
|
|
|
|
|
app.py
CHANGED
@@ -9,8 +9,7 @@ import logging
|
|
9 |
from os import environ
|
10 |
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
11 |
from bot import Bot, Message
|
12 |
-
from
|
13 |
-
from clickhouse_connect import get_client
|
14 |
from classifier import Classifier, prompt2vec, tune, SplitLayer
|
15 |
from query_model import simple_query, topk_obj_query, rev_query
|
16 |
from card_model import card, obj_card, style
|
@@ -63,11 +62,11 @@ def init_db():
|
|
63 |
client: Database connection object
|
64 |
"""
|
65 |
meta = []
|
66 |
-
|
67 |
-
|
68 |
-
host=r['host'], port=r['port'], user=st.secrets["USER"], password=st.secrets["PASSWD"],
|
69 |
-
interface=r['http_pre'],
|
70 |
)
|
|
|
|
|
71 |
return meta, client
|
72 |
|
73 |
|
@@ -118,7 +117,7 @@ def query(xq, exclude_list=None):
|
|
118 |
IMG_DB_NAME,
|
119 |
OBJ_DB_NAME,
|
120 |
exclude_list=exclude_list,
|
121 |
-
topk=
|
122 |
)
|
123 |
img_ids = [r["img_id"] for r in matches]
|
124 |
if "topk_img_id" not in st.session_state:
|
@@ -141,7 +140,7 @@ def query(xq, exclude_list=None):
|
|
141 |
IMG_DB_NAME,
|
142 |
OBJ_DB_NAME,
|
143 |
thresh=-1,
|
144 |
-
topk=
|
145 |
)
|
146 |
status_bar[0].write("Retrieving Non-TopK in Another TopK Images...")
|
147 |
pbar.progress(75)
|
|
|
9 |
from os import environ
|
10 |
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
11 |
from bot import Bot, Message
|
12 |
+
from myscaledb import Client
|
|
|
13 |
from classifier import Classifier, prompt2vec, tune, SplitLayer
|
14 |
from query_model import simple_query, topk_obj_query, rev_query
|
15 |
from card_model import card, obj_card, style
|
|
|
62 |
client: Database connection object
|
63 |
"""
|
64 |
meta = []
|
65 |
+
client = Client(
|
66 |
+
url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"]
|
|
|
|
|
67 |
)
|
68 |
+
# We can check if the connection is alive
|
69 |
+
assert client.is_alive()
|
70 |
return meta, client
|
71 |
|
72 |
|
|
|
117 |
IMG_DB_NAME,
|
118 |
OBJ_DB_NAME,
|
119 |
exclude_list=exclude_list,
|
120 |
+
topk=5000,
|
121 |
)
|
122 |
img_ids = [r["img_id"] for r in matches]
|
123 |
if "topk_img_id" not in st.session_state:
|
|
|
140 |
IMG_DB_NAME,
|
141 |
OBJ_DB_NAME,
|
142 |
thresh=-1,
|
143 |
+
topk=5000,
|
144 |
)
|
145 |
status_bar[0].write("Retrieving Non-TopK in Another TopK Images...")
|
146 |
pbar.progress(75)
|
classifier.py
CHANGED
@@ -95,8 +95,8 @@ class Classifier:
|
|
95 |
grad = []
|
96 |
# Normalize the weight before inference
|
97 |
# This will constrain the gradient or you will have an explosion on query vector
|
98 |
-
self.weight /= torch.norm(
|
99 |
-
self.weight, p=2, dim=-1, keepdim=True
|
100 |
)
|
101 |
for n in range(self.num_class):
|
102 |
# select all training sample and create labels
|
@@ -109,25 +109,22 @@ class Classifier:
|
|
109 |
# To simplify the query, we separated
|
110 |
# the calculation into class numbers
|
111 |
grad_q_str = f"""
|
112 |
-
SELECT
|
113 |
FROM (
|
114 |
SELECT groupArray(arrayPopBack(prelogit)) AS X,
|
115 |
groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
|
116 |
FROM {self.obj_db} WHERE obj_id IN {objs})"""
|
117 |
-
|
118 |
-
grad.append(torch.as_tensor(grad_))
|
119 |
# update weights
|
120 |
grad = torch.stack(grad, dim=0)
|
121 |
-
self.weight -= 0.
|
122 |
-
self.weight /= torch.norm(
|
123 |
-
self.weight, p=2, dim=-1, keepdim=True
|
124 |
-
)
|
125 |
|
126 |
def get_weights(self):
|
127 |
xq = self.weight.detach().numpy()
|
128 |
return xq
|
129 |
|
130 |
|
|
|
131 |
class SplitLayer(torch.nn.Module):
|
132 |
def forward(self, x):
|
133 |
return torch.split(x, 1, dim=-1)
|
|
|
95 |
grad = []
|
96 |
# Normalize the weight before inference
|
97 |
# This will constrain the gradient or you will have an explosion on query vector
|
98 |
+
self.weight.data /= torch.norm(
|
99 |
+
self.weight.data, p=2, dim=-1, keepdim=True
|
100 |
)
|
101 |
for n in range(self.num_class):
|
102 |
# select all training sample and create labels
|
|
|
109 |
# To simplify the query, we separated
|
110 |
# the calculation into class numbers
|
111 |
grad_q_str = f"""
|
112 |
+
SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
|
113 |
FROM (
|
114 |
SELECT groupArray(arrayPopBack(prelogit)) AS X,
|
115 |
groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
|
116 |
FROM {self.obj_db} WHERE obj_id IN {objs})"""
|
117 |
+
grad.append(torch.as_tensor(self.client.fetch(grad_q_str)[0]['grad']))
|
|
|
118 |
# update weights
|
119 |
grad = torch.stack(grad, dim=0)
|
120 |
+
self.weight -= 0.1 * grad
|
|
|
|
|
|
|
121 |
|
122 |
def get_weights(self):
|
123 |
xq = self.weight.detach().numpy()
|
124 |
return xq
|
125 |
|
126 |
|
127 |
+
|
128 |
class SplitLayer(torch.nn.Module):
|
129 |
def forward(self, x):
|
130 |
return torch.split(x, 1, dim=-1)
|
query_model.py
CHANGED
@@ -19,11 +19,11 @@ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
|
|
19 |
FROM {OBJ_DB_NAME}
|
20 |
JOIN {IMG_DB_NAME}
|
21 |
ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
|
22 |
-
|
23 |
SELECT obj_id FROM (
|
24 |
-
SELECT obj_id, distance(prelogit, {_xq}) AS dist FROM {OBJ_DB_NAME}
|
25 |
-
ORDER BY dist DESC
|
26 |
-
) {_cond} LIMIT
|
27 |
)
|
28 |
""")
|
29 |
_subq_str = ' UNION ALL '.join(_subq_str)
|
@@ -38,7 +38,7 @@ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
|
|
38 |
({_subq_str})
|
39 |
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
40 |
"""
|
41 |
-
xc =
|
42 |
return xc
|
43 |
|
44 |
|
@@ -74,7 +74,7 @@ def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
|
|
74 |
({_subq_str})
|
75 |
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
76 |
"""
|
77 |
-
xc =
|
78 |
return xc
|
79 |
|
80 |
|
@@ -88,11 +88,11 @@ def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
|
|
88 |
subq_str.append(
|
89 |
f"""
|
90 |
SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
|
91 |
-
obj_id, box_cx, box_cy, box_w, box_h, {_l} AS l, distance(prelogit, {_xq}) AS dist
|
92 |
FROM {OBJ_DB_NAME}
|
93 |
JOIN {IMG_DB_NAME}
|
94 |
ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
|
95 |
-
{_thresh}
|
96 |
""")
|
97 |
subq_str = " UNION ALL ".join(subq_str)
|
98 |
q_str = f"""
|
@@ -104,5 +104,5 @@ def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
|
|
104 |
)
|
105 |
GROUP BY l
|
106 |
"""
|
107 |
-
|
108 |
-
return
|
|
|
19 |
FROM {OBJ_DB_NAME}
|
20 |
JOIN {IMG_DB_NAME}
|
21 |
ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
|
22 |
+
PREWHERE obj_id IN (
|
23 |
SELECT obj_id FROM (
|
24 |
+
SELECT obj_id, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist FROM {OBJ_DB_NAME}
|
25 |
+
ORDER BY dist DESC
|
26 |
+
) {_cond} LIMIT 10
|
27 |
)
|
28 |
""")
|
29 |
_subq_str = ' UNION ALL '.join(_subq_str)
|
|
|
38 |
({_subq_str})
|
39 |
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
40 |
"""
|
41 |
+
xc = client.fetch(q_str)
|
42 |
return xc
|
43 |
|
44 |
|
|
|
74 |
({_subq_str})
|
75 |
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
76 |
"""
|
77 |
+
xc = client.fetch(q_str)
|
78 |
return xc
|
79 |
|
80 |
|
|
|
88 |
subq_str.append(
|
89 |
f"""
|
90 |
SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
|
91 |
+
obj_id, box_cx, box_cy, box_w, box_h, {_l} AS l, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist
|
92 |
FROM {OBJ_DB_NAME}
|
93 |
JOIN {IMG_DB_NAME}
|
94 |
ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
|
95 |
+
{_thresh} LIMIT 10
|
96 |
""")
|
97 |
subq_str = " UNION ALL ".join(subq_str)
|
98 |
q_str = f"""
|
|
|
104 |
)
|
105 |
GROUP BY l
|
106 |
"""
|
107 |
+
res = client.fetch(q_str)
|
108 |
+
return res
|
requirements.txt
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
transformers
|
2 |
tqdm
|
3 |
-
|
4 |
-
parse
|
5 |
streamlit
|
6 |
-
altair < 5
|
7 |
numpy
|
8 |
torch
|
9 |
onnx
|
|
|
1 |
transformers
|
2 |
tqdm
|
3 |
+
myscaledb-client==1.1.7
|
|
|
4 |
streamlit
|
|
|
5 |
numpy
|
6 |
torch
|
7 |
onnx
|