SungBeom's picture
Upload folder using huggingface_hub
4a51346
raw
history blame
4.37 kB
import logging
import re
from typing import Optional, Sequence
from clickhouse_connect.datatypes.registry import get_from_name
from clickhouse_connect.driver.common import unescape_identifier
from clickhouse_connect.driver.exceptions import ProgrammingError
from clickhouse_connect.driver import Client
from clickhouse_connect.driver.parser import parse_callable
from clickhouse_connect.driver.query import remove_sql_comments
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Case-insensitive pattern for a statement that begins with INSERT INTO (leading whitespace allowed).
# Group 1 captures everything that follows, i.e. the table name, optional column list and the rest.
insert_re = re.compile(r'^\s*INSERT\s+INTO\s+(.*$)', re.IGNORECASE)
# Type objects looked up by name through the datatype registry.
# NOTE(review): neither is referenced in the visible part of this file -- presumably used by importers; confirm.
str_type = get_from_name('String')
int_type = get_from_name('Int32')
class Cursor:
    """
    DB-API style cursor that runs statements through a ClickHouse Connect client and buffers the
    complete result set in memory.

    See :ref:`https://peps.python.org/pep-0249/`
    """

    def __init__(self, client: Client):
        """
        :param client: Client used to run every query and insert issued through this cursor
        """
        self.client = client
        self.arraysize = 1
        # None means "no usable result" (never executed, or closed); see check_valid
        self.data: Optional[Sequence] = None
        self.names = []
        self.types = []
        self._rowcount = 0
        # Index of the next row handed out by the fetch* methods
        self._ix = 0

    def check_valid(self):
        """
        Ensure the cursor currently holds a result.

        :raises ProgrammingError: if no statement has been executed or the cursor was closed
        """
        if self.data is None:
            raise ProgrammingError('Cursor is not valid')

    @property
    def description(self):
        """
        Column metadata as a list of 7-item tuples, one per column.  Only the name and type slots are
        populated; the next four are None and the last is always True.
        """
        return [(n, t, None, None, None, None, True) for n, t in zip(self.names, self.types)]

    @property
    def rowcount(self):
        """Number of rows buffered by the most recent execute/executemany call."""
        return self._rowcount

    def close(self):
        """Drop the buffered result; subsequent fetch calls raise ProgrammingError."""
        self.data = None

    def execute(self, operation: str, parameters=None):
        """
        Run a single statement through the client and buffer its result set.

        :param operation: Statement text
        :param parameters: Optional parameters, passed to the client query unchanged
        """
        query_result = self.client.query(operation, parameters)
        self.data = query_result.result_set
        self._rowcount = len(self.data)
        # Start reading the new result from its first row, even if a previous result was partly or
        # fully consumed through this cursor
        self._ix = 0
        if query_result.column_names:
            self.names = query_result.column_names
            self.types = [x.name for x in query_result.column_types]
        elif self.data:
            # No column metadata came back, so synthesize names and use the Python classes of the
            # values in the first row as the types
            self.names = [f'col_{x}' for x in range(len(self.data[0]))]
            self.types = [x.__class__ for x in self.data[0]]

    def _try_bulk_insert(self, operation: str, data):
        """
        Attempt to run an ``INSERT INTO <table> [(columns)] VALUES ...`` statement as one client insert.

        :param operation: Statement text
        :param data: Non-empty sequence of rows; the bulk path is only taken for dict-like rows
        :return: True if the rows were inserted with a single client insert, False if the statement or
          the data does not fit the bulk path and the caller should execute it row by row
        """
        match = insert_re.match(remove_sql_comments(operation))
        if not match:
            return False
        temp = match.group(1)
        # The table name ends at the first space or opening parenthesis.  str.find returns -1 when the
        # character is absent, so missing separators must be discarded before taking the minimum;
        # otherwise -1 wins and the last character of the statement is chopped off instead.
        separators = [pos for pos in (temp.find(' '), temp.find('(')) if pos >= 0]
        if not separators:
            return False  # Nothing follows the table name, so there is no VALUES clause
        table_end = min(separators)
        table = temp[:table_end].strip()
        temp = temp[table_end:].strip()
        if temp.startswith('('):
            _, op_columns, temp = parse_callable(temp)
        else:
            op_columns = None
        if 'VALUES' not in temp.upper():
            return False
        first_row = data[0]
        if not hasattr(first_row, 'keys'):
            return False  # Positional rows carry no column names; leave them to the row by row path
        col_names = list(first_row.keys())
        if op_columns and {unescape_identifier(x) for x in op_columns} != set(col_names):
            return False  # Data sent in doesn't match the columns in the insert statement
        data_values = [list(row.values()) for row in data]
        self.client.insert(table, data_values, col_names)
        # An insert produces no rows; keep the counters in step with the empty buffer so a later
        # fetchone() cannot index past it using a row count left over from an earlier query
        self.data = []
        self._rowcount = 0
        self._ix = 0
        return True

    def executemany(self, operation, parameters):
        """
        Run ``operation`` once per entry in ``parameters``, accumulating every returned row.

        Nothing happens when ``parameters`` is empty.  Statements that qualify are sent as a single
        bulk insert instead (see _try_bulk_insert).

        :param operation: Statement text
        :param parameters: Sequence of per-execution parameter sets
        :raises ProgrammingError: if a TypeError is raised while iterating or executing the parameters
        """
        if not parameters or self._try_bulk_insert(operation, parameters):
            return
        self.data = []
        # Forget metadata from any earlier statement so the consistency check below only compares
        # executions of this operation with each other
        self.names = []
        self.types = []
        self._rowcount = 0
        self._ix = 0
        try:
            for param_row in parameters:
                query_result = self.client.query(operation, param_row)
                self.data.extend(query_result.result_set)
                if self.names or self.types:
                    if query_result.column_names != self.names:
                        logger.warning('Inconsistent column names %s : %s for operation %s in cursor executemany',
                                       self.names, query_result.column_names, operation)
                else:
                    self.names = query_result.column_names
                    # Store type names, matching what execute() exposes through description
                    self.types = [x.name for x in query_result.column_types]
        except TypeError as ex:
            raise ProgrammingError(f'Invalid parameters {parameters} passed to cursor executemany') from ex
        self._rowcount = len(self.data)

    def fetchall(self):
        """
        Return every row that has not been fetched yet and mark the result as fully consumed.

        :raises ProgrammingError: if the cursor holds no result
        """
        self.check_valid()
        ret = self.data[self._ix:]
        self._ix = self._rowcount
        return ret

    def fetchone(self):
        """
        Return the next row, or None once the result is exhausted.

        :raises ProgrammingError: if the cursor holds no result
        """
        self.check_valid()
        if self._ix >= self._rowcount:
            return None
        val = self.data[self._ix]
        self._ix += 1
        return val

    def fetchmany(self, size: int = -1):
        """
        Return up to ``size`` of the remaining rows.

        :param size: Maximum number of rows to return; a negative value (the default) returns all
          remaining rows
        :raises ProgrammingError: if the cursor holds no result
        """
        self.check_valid()
        remaining = max(self._rowcount - self._ix, 0)
        count = remaining if size is None or size < 0 else min(size, remaining)
        end = self._ix + count
        ret = self.data[self._ix: end]
        self._ix = end
        return ret

    def nextset(self):
        """Multiple result sets are not supported."""
        raise NotImplementedError

    def callproc(self, *args, **kwargs):
        """Stored procedure calls are not supported."""
        raise NotImplementedError