Spaces:
Running
Running
File size: 2,898 Bytes
b247dc4 eec31b9 6da1916 b247dc4 6da1916 b247dc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
"""Text2SQL schemas."""
import enum
from manifest.response import Usage
from pydantic.v1 import BaseModel
from typing import Any, List, Sequence, TypeVar
# Fallback table name; presumably used when a table has no explicit name
# (TODO confirm against callers).
DEFAULT_TABLE_NAME: str = "db_table"
class Dialect(str, enum.Enum):
"""SQGFluff and SQLGlot dialects.
Lucky for us, the dialects match both parsers.
Ref: https://github.com/sqlfluff/sqlfluff/blob/main/src/sqlfluff/core/dialects/__init__.py # noqa: E501
Ref: https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py # noqa: E501
"""
SNOWFLAKE = "snowflake"
BIGQUERY = "bigquery"
REDSHIFT = "redshift"
POSTGRES = "postgres"
UNKNOWN = "unknown"
@property
def dialect_str(self) -> str | None:
"""Get the dialect string for validation.
We need to pass in dialect = None for UNKNOWN dialects.
"""
if self != Dialect.UNKNOWN:
return self.value
else:
return None
@property
def quote_str(self) -> str:
"""Get the quote string for the dialect."""
if self == Dialect.SNOWFLAKE:
return '"'
elif self == Dialect.BIGQUERY:
return "`"
elif self == Dialect.REDSHIFT:
return '"'
elif self == Dialect.POSTGRES:
return '"'
elif self == Dialect.UNKNOWN:
return '"'
raise NotImplementedError(f"Quote string not implemented for dialect {self}")
def quote(self, string: str) -> str:
"""Quote a string."""
return f"{self.quote_str}{string}{self.quote_str}"
class ColumnOrLiteral(BaseModel):
    """A column that may instead be a literal.

    ``literal`` marks the literal case and defaults to ``False``; ``name`` is
    optional and defaults to ``None``.
    """

    name: str | None = None
    literal: bool = False

    def __hash__(self) -> int:
        """Hash on the ``(name, literal)`` pair of declared fields.

        This makes instances usable as set members and dict keys.
        """
        identity = (self.name, self.literal)
        return hash(identity)
class TableColumn(BaseModel):
    """A single table column: its name and (optional) data type string."""

    # Column name.
    name: str
    # Data type as a string; may be None (presumably when the type is
    # unknown — TODO confirm).
    # NOTE(review): no explicit default — pydantic v1 gives Optional fields an
    # implicit None default; confirm whether that is intended here.
    dtype: str | None
class ForeignKey(BaseModel):
    """Foreign key from a column of the owning table to a column of another table."""

    # Referencing column, in the table that owns this foreign key.
    column: TableColumn
    # Name of the referenced (target) table.
    references_name: str
    # Column in the referenced table that ``column`` points to.
    references_column: TableColumn
class Table(BaseModel):
    """Description of a single database table.

    All fields except ``is_reference_table`` are optional and may be None.
    """

    # Table name.
    name: str | None
    # Column definitions.
    columns: list[TableColumn] | None
    # Primary-key columns.
    pks: list[TableColumn] | None
    # FK from this table to another column in another table
    fks: list[ForeignKey] | None
    # Example rows; presumably each dict maps column name -> value
    # (TODO confirm).
    examples: list[dict] | None
    # True if the table is a source or intermediate reference table;
    # defaults to False.
    is_reference_table: bool = False
class TextToSQLParams(BaseModel):
    """Parameters of a text-to-SQL request."""

    # Instruction to convert to SQL; presumably natural language
    # (TODO confirm).
    instruction: str
    # Database name; may be None.
    database: str | None
    # SQL dialect of the request. Default to unknown
    dialect: Dialect = Dialect.UNKNOWN
    # Table descriptions supplied with the request; may be None.
    tables: list[Table] | None
class TextToSQLModelResponse(BaseModel):
    """Model response for a text-to-SQL request.

    NOTE(review): the original docstring called this "Autocomplete
    Responses"; confirm whether it is also used for autocomplete.
    """

    # Output text; presumably the post-processed result of ``raw_output``
    # (TODO confirm).
    output: str
    # Prompt that was sent, either as a plain string or as a list of dicts
    # (presumably chat-style messages — TODO confirm).
    final_prompt: str | list[dict]
    # Unprocessed model output.
    raw_output: str
    # Usage information, typed Any. Possibly a manifest ``Usage`` object
    # (imported by this module but unused) — TODO confirm.
    usage: Any
    # Optional metadata string; defaults to None.
    metadata: str | None = None
|