# Spaces:
# Runtime error
# Runtime error
import logging
import re
from typing import Optional, Sequence

from clickhouse_connect.datatypes.registry import get_from_name
from clickhouse_connect.driver.common import unescape_identifier
from clickhouse_connect.driver.exceptions import ProgrammingError
from clickhouse_connect.driver import Client
from clickhouse_connect.driver.parser import parse_callable
from clickhouse_connect.driver.query import remove_sql_comments
logger = logging.getLogger(__name__)

# Matches a statement that begins with ``INSERT INTO`` (case-insensitive, leading whitespace
# allowed); group 1 captures everything after it up to the end of the string.
insert_re = re.compile(r'^\s*INSERT\s+INTO\s+(.*$)', re.IGNORECASE)

# Type objects looked up by name in the datatype registry.
# NOTE(review): neither is referenced in the visible part of this module -- confirm external use.
str_type = get_from_name('String')
int_type = get_from_name('Int32')
class Cursor:
    """
    In-memory DB API 2.0 style cursor on top of a clickhouse_connect ``Client``.

    See :ref:`https://peps.python.org/pep-0249/`

    The complete result of the most recent operation is kept in ``data``.  The ``fetch*`` methods
    hand out rows from that buffer and advance an internal read position (``_ix``); ``_rowcount``
    is always the number of rows currently buffered.
    """

    # The annotation is quoted so it is not evaluated when the class body is executed.
    def __init__(self, client: 'Client'):
        """
        :param client: client used to run every query and insert issued through this cursor
        """
        self.client = client
        self.arraysize = 1
        # None means "no result available" (never executed, or closed) -- see check_valid
        self.data: Optional[Sequence] = None
        self.names = []
        self.types = []
        self._rowcount = 0
        self._ix = 0

    def check_valid(self):
        """
        Ensure the cursor currently holds a result buffer.

        :raises ProgrammingError: if nothing has been executed yet or the cursor was closed
        """
        if self.data is None:
            raise ProgrammingError('Cursor is not valid')

    def description(self):
        """
        Describe the columns of the current result.

        PEP 249 specifies ``description`` as an attribute; it is kept as a method here so that
        existing callers continue to work.

        :return: one 7-item tuple ``(name, type, None, None, None, None, True)`` per column
        """
        return [(n, t, None, None, None, None, True) for n, t in zip(self.names, self.types)]

    def rowcount(self):
        """
        :return: number of rows buffered by the last operation (kept as a method, see description)
        """
        return self._rowcount

    def close(self):
        """Drop the result buffer; subsequent fetches raise until the next execute."""
        self.data = None

    def execute(self, operation: str, parameters=None):
        """
        Run a single query and buffer its complete result set.

        :param operation: SQL text passed to ``client.query``
        :param parameters: optional bind parameters passed through to ``client.query``
        """
        query_result = self.client.query(operation, parameters)
        self.data = query_result.result_set
        self._rowcount = len(self.data)
        # A reused cursor must start reading the new result from the first row
        self._ix = 0
        if query_result.column_names:
            self.names = query_result.column_names
            self.types = [x.name for x in query_result.column_types]
        elif self.data:
            # No column metadata: synthesize names and use the Python classes of the first row
            first_row = self.data[0]
            self.names = [f'col_{x}' for x in range(len(first_row))]
            self.types = [x.__class__ for x in first_row]
        else:
            # Nothing describes this result; do not keep metadata from a previous operation
            self.names = []
            self.types = []

    @staticmethod
    def _split_insert_target(remainder: str):
        """
        Split the text following ``INSERT INTO`` into the table name and whatever comes after it.

        The table name ends at the first whitespace character or opening parenthesis, whichever
        comes first; if neither is present the whole text is the table name.

        :param remainder: statement text after ``INSERT INTO``
        :return: tuple of (table, rest), both stripped of surrounding whitespace
        """
        boundary = re.search(r'[\s(]', remainder)
        if boundary is None:
            return remainder.strip(), ''
        cut = boundary.start()
        return remainder[:cut].strip(), remainder[cut:].strip()

    def _try_bulk_insert(self, operation: str, data):
        """
        Attempt to run an ``INSERT INTO ... VALUES`` statement as one ``client.insert`` call.

        Only applies when every parameter row is a mapping with the same keys and, if the
        statement names columns, those names match the keys.  In every other case nothing is
        sent and the caller is expected to fall back to row-by-row execution.

        :param operation: SQL text of the statement
        :param data: parameter rows given to executemany
        :return: True if the bulk insert was performed, otherwise False
        """
        match = insert_re.match(remove_sql_comments(operation))
        if not match:
            return False
        table, temp = self._split_insert_target(match.group(1))
        if not table:
            return False
        if temp.startswith('('):
            # Parenthesized column list directly after the table name
            _, op_columns, temp = parse_callable(temp)
        else:
            op_columns = None
        if 'VALUES' not in temp.upper():
            return False
        try:
            col_names = list(data[0].keys())
            data_values = []
            for row in data:
                if len(row) != len(col_names):
                    return False
                # Read values by column name so rows with a different key order stay aligned
                data_values.append([row[name] for name in col_names])
        except (AttributeError, IndexError, KeyError, TypeError):
            # Rows are not uniform mappings (e.g. positional tuples) -- not a bulk insert candidate
            return False
        if op_columns and {unescape_identifier(x) for x in op_columns} != set(col_names):
            return False  # Data sent in doesn't match the columns in the insert statement
        self.client.insert(table, data_values, col_names)
        # An insert produces no rows; keep the buffer, its length and the read position in sync
        self.data = []
        self._rowcount = 0
        self._ix = 0
        return True

    def executemany(self, operation, parameters):
        """
        Run ``operation`` once per parameter row and buffer the concatenated results.

        A suitable ``INSERT ... VALUES`` statement with mapping rows is sent as a single bulk
        insert instead.  If ``parameters`` is empty nothing is executed and the cursor state is
        left untouched.

        :param operation: SQL text passed to ``client.query`` for each row
        :param parameters: iterable of bind parameter sets
        :raises ProgrammingError: if a TypeError occurs while iterating or binding the parameters
        """
        if not parameters or self._try_bulk_insert(operation, parameters):
            return
        # Start from a clean, consistent state so a failure part-way leaves an empty valid cursor
        self.data = []
        self._rowcount = 0
        self._ix = 0
        self.names = []
        self.types = []
        rows = []
        try:
            for param_row in parameters:
                query_result = self.client.query(operation, param_row)
                rows.extend(query_result.result_set)
                if self.names or self.types:
                    if query_result.column_names != self.names:
                        logger.warning('Inconsistent column names %s : %s for operation %s in cursor executemany',
                                       self.names, query_result.column_names, operation)
                elif query_result.column_names:
                    self.names = query_result.column_names
                    # Store type names, matching what execute puts into the description
                    self.types = [x.name for x in query_result.column_types]
        except TypeError as ex:
            raise ProgrammingError(f'Invalid parameters {parameters} passed to cursor executemany') from ex
        self.data = rows
        self._rowcount = len(rows)

    def fetchall(self):
        """
        Return every row that has not been fetched yet and mark the result as exhausted.

        :return: the remaining rows (the buffer itself when nothing has been fetched so far)
        :raises ProgrammingError: if the cursor holds no result
        """
        self.check_valid()
        # Avoid copying the whole buffer in the common case where no row was consumed yet
        ret = self.data if self._ix == 0 else self.data[self._ix:]
        self._ix = self._rowcount
        return ret

    def fetchone(self):
        """
        Return the next row, or None when the result is exhausted.

        :raises ProgrammingError: if the cursor holds no result
        """
        self.check_valid()
        if self._ix >= self._rowcount:
            return None
        val = self.data[self._ix]
        self._ix += 1
        return val

    def fetchmany(self, size: int = -1):
        """
        Return up to ``size`` of the remaining rows.

        :param size: maximum number of rows to return; a negative value (the default) or None
          returns all remaining rows
        :return: the next rows, possibly fewer than ``size`` and empty once exhausted
        :raises ProgrammingError: if the cursor holds no result
        """
        self.check_valid()
        remaining = self._rowcount - self._ix
        if size is None or size < 0:
            count = remaining
        else:
            count = min(size, remaining)
        end = self._ix + count
        ret = self.data[self._ix: end]
        self._ix = end
        return ret

    def nextset(self):
        """Multiple result sets are not supported."""
        raise NotImplementedError

    def callproc(self, *args, **kwargs):
        """Stored procedure calls are not supported."""
        raise NotImplementedError