Fangrui Liu
commited on
Commit
·
8dda822
1
Parent(s):
88c0383
update to clickhouse python client
Browse files- .gitignore +2 -0
- app.py +5 -5
- classifier.py +2 -2
- query_model.py +4 -4
- requirements.txt +2 -1
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.streamlit/
|
2 |
+
__pycache__
|
app.py
CHANGED
@@ -9,7 +9,8 @@ import logging
|
|
9 |
from os import environ
|
10 |
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
11 |
from bot import Bot, Message
|
12 |
-
from
|
|
|
13 |
from classifier import Classifier, prompt2vec, tune, SplitLayer
|
14 |
from query_model import simple_query, topk_obj_query, rev_query
|
15 |
from card_model import card, obj_card, style
|
@@ -62,11 +63,10 @@ def init_db():
|
|
62 |
client: Database connection object
|
63 |
"""
|
64 |
meta = []
|
65 |
-
|
66 |
-
|
|
|
67 |
)
|
68 |
-
# We can check if the connection is alive
|
69 |
-
assert client.is_alive()
|
70 |
return meta, client
|
71 |
|
72 |
|
|
|
9 |
from os import environ
|
10 |
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
11 |
from bot import Bot, Message
|
12 |
+
from parse import parse
|
13 |
+
from clickhouse_connect import get_client
|
14 |
from classifier import Classifier, prompt2vec, tune, SplitLayer
|
15 |
from query_model import simple_query, topk_obj_query, rev_query
|
16 |
from card_model import card, obj_card, style
|
|
|
63 |
client: Database connection object
|
64 |
"""
|
65 |
meta = []
|
66 |
+
r = parse("{http_pre}://{host}:{port}", st.secrets["DB_URL"])
|
67 |
+
client = get_client(
|
68 |
+
host=r['host'], port=r['port'], user=st.secrets["USER"], password=st.secrets["PASSWD"]
|
69 |
)
|
|
|
|
|
70 |
return meta, client
|
71 |
|
72 |
|
classifier.py
CHANGED
@@ -114,7 +114,8 @@ class Classifier:
|
|
114 |
SELECT groupArray(arrayPopBack(prelogit)) AS X,
|
115 |
groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
|
116 |
FROM {self.obj_db} WHERE obj_id IN {objs})"""
|
117 |
-
grad
|
|
|
118 |
# update weights
|
119 |
grad = torch.stack(grad, dim=0)
|
120 |
self.weight -= 0.1 * grad
|
@@ -124,7 +125,6 @@ class Classifier:
|
|
124 |
return xq
|
125 |
|
126 |
|
127 |
-
|
128 |
class SplitLayer(torch.nn.Module):
|
129 |
def forward(self, x):
|
130 |
return torch.split(x, 1, dim=-1)
|
|
|
114 |
SELECT groupArray(arrayPopBack(prelogit)) AS X,
|
115 |
groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
|
116 |
FROM {self.obj_db} WHERE obj_id IN {objs})"""
|
117 |
+
grad_ = [r['grad'] for r in self.client.query(grad_q_str).named_results()][0]
|
118 |
+
grad.append(torch.as_tensor(grad_))
|
119 |
# update weights
|
120 |
grad = torch.stack(grad, dim=0)
|
121 |
self.weight -= 0.1 * grad
|
|
|
125 |
return xq
|
126 |
|
127 |
|
|
|
128 |
class SplitLayer(torch.nn.Module):
|
129 |
def forward(self, x):
|
130 |
return torch.split(x, 1, dim=-1)
|
query_model.py
CHANGED
@@ -38,7 +38,7 @@ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
|
|
38 |
({_subq_str})
|
39 |
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
40 |
"""
|
41 |
-
xc = client.
|
42 |
return xc
|
43 |
|
44 |
|
@@ -74,7 +74,7 @@ def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
|
|
74 |
({_subq_str})
|
75 |
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
76 |
"""
|
77 |
-
xc = client.
|
78 |
return xc
|
79 |
|
80 |
|
@@ -104,5 +104,5 @@ def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
|
|
104 |
)
|
105 |
GROUP BY l
|
106 |
"""
|
107 |
-
|
108 |
-
return
|
|
|
38 |
({_subq_str})
|
39 |
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
40 |
"""
|
41 |
+
xc = [{k: v for k, v in r.items()} for r in client.query(q_str).named_results()]
|
42 |
return xc
|
43 |
|
44 |
|
|
|
74 |
({_subq_str})
|
75 |
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
76 |
"""
|
77 |
+
xc = [{k: v for k, v in r.items()} for r in client.query(q_str).named_results()]
|
78 |
return xc
|
79 |
|
80 |
|
|
|
104 |
)
|
105 |
GROUP BY l
|
106 |
"""
|
107 |
+
xc = [{k: v for k, v in r.items()} for r in client.query(q_str).named_results()]
|
108 |
+
return xc
|
requirements.txt
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
transformers
|
2 |
tqdm
|
3 |
-
|
|
|
4 |
streamlit
|
5 |
numpy
|
6 |
torch
|
|
|
1 |
transformers
|
2 |
tqdm
|
3 |
+
clickhouse-connect
|
4 |
+
parse
|
5 |
streamlit
|
6 |
numpy
|
7 |
torch
|