Fangrui Liu commited on
Commit
8dda822
·
1 Parent(s): 88c0383

update to clickhouse python client

Browse files
Files changed (5) hide show
  1. .gitignore +2 -0
  2. app.py +5 -5
  3. classifier.py +2 -2
  4. query_model.py +4 -4
  5. requirements.txt +2 -1
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .streamlit/
2
+ __pycache__
app.py CHANGED
@@ -9,7 +9,8 @@ import logging
9
  from os import environ
10
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
11
  from bot import Bot, Message
12
- from myscaledb import Client
 
13
  from classifier import Classifier, prompt2vec, tune, SplitLayer
14
  from query_model import simple_query, topk_obj_query, rev_query
15
  from card_model import card, obj_card, style
@@ -62,11 +63,10 @@ def init_db():
62
  client: Database connection object
63
  """
64
  meta = []
65
- client = Client(
66
- url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"]
 
67
  )
68
- # We can check if the connection is alive
69
- assert client.is_alive()
70
  return meta, client
71
 
72
 
 
9
  from os import environ
10
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
11
  from bot import Bot, Message
12
+ from parse import parse
13
+ from clickhouse_connect import get_client
14
  from classifier import Classifier, prompt2vec, tune, SplitLayer
15
  from query_model import simple_query, topk_obj_query, rev_query
16
  from card_model import card, obj_card, style
 
63
  client: Database connection object
64
  """
65
  meta = []
66
+ r = parse("{http_pre}://{host}:{port}", st.secrets["DB_URL"])
67
+ client = get_client(
68
+ host=r['host'], port=r['port'], user=st.secrets["USER"], password=st.secrets["PASSWD"]
69
  )
 
 
70
  return meta, client
71
 
72
 
classifier.py CHANGED
@@ -114,7 +114,8 @@ class Classifier:
114
  SELECT groupArray(arrayPopBack(prelogit)) AS X,
115
  groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
116
  FROM {self.obj_db} WHERE obj_id IN {objs})"""
117
- grad.append(torch.as_tensor(self.client.fetch(grad_q_str)[0]['grad']))
 
118
  # update weights
119
  grad = torch.stack(grad, dim=0)
120
  self.weight -= 0.1 * grad
@@ -124,7 +125,6 @@ class Classifier:
124
  return xq
125
 
126
 
127
-
128
  class SplitLayer(torch.nn.Module):
129
  def forward(self, x):
130
  return torch.split(x, 1, dim=-1)
 
114
  SELECT groupArray(arrayPopBack(prelogit)) AS X,
115
  groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
116
  FROM {self.obj_db} WHERE obj_id IN {objs})"""
117
+ grad_ = [r['grad'] for r in self.client.query(grad_q_str).named_results()][0]
118
+ grad.append(torch.as_tensor(grad_))
119
  # update weights
120
  grad = torch.stack(grad, dim=0)
121
  self.weight -= 0.1 * grad
 
125
  return xq
126
 
127
 
 
128
  class SplitLayer(torch.nn.Module):
129
  def forward(self, x):
130
  return torch.split(x, 1, dim=-1)
query_model.py CHANGED
@@ -38,7 +38,7 @@ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
38
  ({_subq_str})
39
  GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
40
  """
41
- xc = client.fetch(q_str)
42
  return xc
43
 
44
 
@@ -74,7 +74,7 @@ def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
74
  ({_subq_str})
75
  GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
76
  """
77
- xc = client.fetch(q_str)
78
  return xc
79
 
80
 
@@ -104,5 +104,5 @@ def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
104
  )
105
  GROUP BY l
106
  """
107
- res = client.fetch(q_str)
108
- return res
 
38
  ({_subq_str})
39
  GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
40
  """
41
+ xc = [{k: v for k, v in r.items()} for r in client.query(q_str).named_results()]
42
  return xc
43
 
44
 
 
74
  ({_subq_str})
75
  GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
76
  """
77
+ xc = [{k: v for k, v in r.items()} for r in client.query(q_str).named_results()]
78
  return xc
79
 
80
 
 
104
  )
105
  GROUP BY l
106
  """
107
+ xc = [{k: v for k, v in r.items()} for r in client.query(q_str).named_results()]
108
+ return xc
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  transformers
2
  tqdm
3
- myscaledb-client==1.1.7
 
4
  streamlit
5
  numpy
6
  torch
 
1
  transformers
2
  tqdm
3
+ clickhouse-connect
4
+ parse
5
  streamlit
6
  numpy
7
  torch