|
"""Text2SQL schemas.""" |
|
import enum |
|
|
|
from manifest.response import Usage |
|
from pydantic.v1 import BaseModel |
|
from typing import Any, List, Sequence, TypeVar |
|
|
|
DEFAULT_TABLE_NAME: str = "db_table" |
|
|
|
|
|
class Dialect(str, enum.Enum): |
|
"""SQGFluff and SQLGlot dialects. |
|
|
|
Lucky for us, the dialects match both parsers. |
|
|
|
Ref: https://github.com/sqlfluff/sqlfluff/blob/main/src/sqlfluff/core/dialects/__init__.py # noqa: E501 |
|
Ref: https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py # noqa: E501 |
|
""" |
|
|
|
SNOWFLAKE = "snowflake" |
|
BIGQUERY = "bigquery" |
|
REDSHIFT = "redshift" |
|
POSTGRES = "postgres" |
|
UNKNOWN = "unknown" |
|
|
|
@property |
|
def dialect_str(self) -> str | None: |
|
"""Get the dialect string for validation. |
|
|
|
We need to pass in dialect = None for UNKNOWN dialects. |
|
""" |
|
if self != Dialect.UNKNOWN: |
|
return self.value |
|
else: |
|
return None |
|
|
|
@property |
|
def quote_str(self) -> str: |
|
"""Get the quote string for the dialect.""" |
|
if self == Dialect.SNOWFLAKE: |
|
return '"' |
|
elif self == Dialect.BIGQUERY: |
|
return "`" |
|
elif self == Dialect.REDSHIFT: |
|
return '"' |
|
elif self == Dialect.POSTGRES: |
|
return '"' |
|
elif self == Dialect.UNKNOWN: |
|
return '"' |
|
raise NotImplementedError(f"Quote string not implemented for dialect {self}") |
|
|
|
def quote(self, string: str) -> str: |
|
"""Quote a string.""" |
|
return f"{self.quote_str}{string}{self.quote_str}" |
|
|
|
|
|
class ColumnOrLiteral(BaseModel): |
|
"""Column that may or may not be a literal.""" |
|
|
|
name: str | None = None |
|
literal: bool = False |
|
|
|
def __hash__(self) -> int: |
|
"""Hash.""" |
|
return hash((self.name, self.literal)) |
|
|
|
|
|
class TableColumn(BaseModel): |
|
"""Table column.""" |
|
|
|
name: str |
|
dtype: str | None |
|
|
|
|
|
class ForeignKey(BaseModel): |
|
"""Foreign key.""" |
|
|
|
|
|
column: TableColumn |
|
|
|
references_name: str |
|
|
|
references_column: TableColumn |
|
|
|
|
|
class Table(BaseModel): |
|
"""Table.""" |
|
|
|
name: str | None |
|
columns: list[TableColumn] | None |
|
pks: list[TableColumn] | None |
|
|
|
fks: list[ForeignKey] | None |
|
examples: list[dict] | None |
|
|
|
is_reference_table: bool = False |
|
|
|
|
|
class TextToSQLParams(BaseModel): |
|
"""A text to sql request.""" |
|
|
|
instruction: str |
|
database: str | None |
|
|
|
dialect: Dialect = Dialect.UNKNOWN |
|
tables: list[Table] | None |
|
|
|
|
|
class TextToSQLModelResponse(BaseModel): |
|
"""Model for Autocomplete Responses.""" |
|
|
|
output: str |
|
final_prompt: str | list[dict] |
|
raw_output: str |
|
usage: Any |
|
metadata: str | None = None |
|
|