File size: 2,852 Bytes
b247dc4
 
 
 
eec31b9
b247dc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Text2SQL schemas."""
import enum

from manifest.response import Usage
from pydantic.v1 import BaseModel

# Table name to fall back on; its consumers are not visible in this module.
DEFAULT_TABLE_NAME: str = "db_table"


class Dialect(str, enum.Enum):
    """SQGFluff and SQLGlot dialects.

    Lucky for us, the dialects match both parsers.

    Ref: https://github.com/sqlfluff/sqlfluff/blob/main/src/sqlfluff/core/dialects/__init__.py  # noqa: E501
    Ref: https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py  # noqa: E501
    """

    SNOWFLAKE = "snowflake"
    BIGQUERY = "bigquery"
    REDSHIFT = "redshift"
    POSTGRES = "postgres"
    UNKNOWN = "unknown"

    @property
    def dialect_str(self) -> str | None:
        """Get the dialect string for validation.

        We need to pass in dialect = None for UNKNOWN dialects.
        """
        if self != Dialect.UNKNOWN:
            return self.value
        else:
            return None

    @property
    def quote_str(self) -> str:
        """Get the quote string for the dialect."""
        if self == Dialect.SNOWFLAKE:
            return '"'
        elif self == Dialect.BIGQUERY:
            return "`"
        elif self == Dialect.REDSHIFT:
            return '"'
        elif self == Dialect.POSTGRES:
            return '"'
        elif self == Dialect.UNKNOWN:
            return '"'
        raise NotImplementedError(f"Quote string not implemented for dialect {self}")

    def quote(self, string: str) -> str:
        """Quote a string."""
        return f"{self.quote_str}{string}{self.quote_str}"


class ColumnOrLiteral(BaseModel):
    """A named column that can alternatively be flagged as a literal."""

    name: str | None = None
    literal: bool = False

    def __hash__(self) -> int:
        """Hash on the (name, literal) pair."""
        identity = (self.name, self.literal)
        return hash(identity)


class TableColumn(BaseModel):
    """A table column: its name plus an optional type string."""

    # Column name.
    name: str
    # Column data type as a string; may be None (presumably when the type
    # is unknown -- confirm against callers).
    dtype: str | None


class ForeignKey(BaseModel):
    """Foreign key linking one column to a column of a named table."""

    # Column carrying the foreign key (presumably a column of the table
    # that owns this key -- confirm against callers).
    column: TableColumn
    # Name of the table being referenced.
    references_name: str
    # Column being referenced in that table.
    references_column: TableColumn


class Table(BaseModel):
    """Description of a table: name, columns, keys, and example rows."""

    # Table name; may be None.
    name: str | None
    # Columns of the table; may be None.
    columns: list[TableColumn] | None
    # Primary-key columns ("pks" presumably abbreviates primary keys -- confirm).
    pks: list[TableColumn] | None
    # FK from this table to another column in another table
    fks: list[ForeignKey] | None
    # Example records as dicts (presumably column name -> value -- confirm).
    examples: list[dict] | None
    # Is the table a source or intermediate reference table
    is_reference_table: bool = False


class TextToSQLParams(BaseModel):
    """Parameters of a text-to-SQL request."""

    # Instruction text for the request (presumably the natural-language
    # question to turn into SQL -- confirm).
    instruction: str
    # Database identifier; may be None.
    database: str | None
    # SQL dialect of the request; falls back to Dialect.UNKNOWN.
    dialect: Dialect = Dialect.UNKNOWN
    # Tables supplied with the request; may be None.
    tables: list[Table] | None


class TextToSQLModelResponse(BaseModel):
    """Response model for a text-to-SQL model call."""

    # Output text (presumably the post-processed SQL -- confirm).
    output: str
    # Prompt sent to the model: either a single string or a list of dicts
    # (presumably chat-style messages -- confirm).
    final_prompt: str | list[dict]
    # Raw output text (presumably the model output before post-processing
    # -- confirm).
    raw_output: str
    # Usage record, typed with manifest's Usage class.
    usage: Usage
    # Optional free-form metadata string; defaults to None.
    metadata: str | None = None