File size: 2,898 Bytes
b247dc4
 
 
 
eec31b9
6da1916
b247dc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6da1916
b247dc4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Text2SQL schemas."""
import enum

from manifest.response import Usage
from pydantic.v1 import BaseModel
from typing import Any, List, Sequence, TypeVar

# Default table name ("db_table"). NOTE(review): presumably used as a fallback
# when a table has no name (``Table.name`` is optional) — confirm with callers.
DEFAULT_TABLE_NAME: str = "db_table"


class Dialect(str, enum.Enum):
    """SQGFluff and SQLGlot dialects.

    Lucky for us, the dialects match both parsers.

    Ref: https://github.com/sqlfluff/sqlfluff/blob/main/src/sqlfluff/core/dialects/__init__.py  # noqa: E501
    Ref: https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py  # noqa: E501
    """

    SNOWFLAKE = "snowflake"
    BIGQUERY = "bigquery"
    REDSHIFT = "redshift"
    POSTGRES = "postgres"
    UNKNOWN = "unknown"

    @property
    def dialect_str(self) -> str | None:
        """Get the dialect string for validation.

        We need to pass in dialect = None for UNKNOWN dialects.
        """
        if self != Dialect.UNKNOWN:
            return self.value
        else:
            return None

    @property
    def quote_str(self) -> str:
        """Get the quote string for the dialect."""
        if self == Dialect.SNOWFLAKE:
            return '"'
        elif self == Dialect.BIGQUERY:
            return "`"
        elif self == Dialect.REDSHIFT:
            return '"'
        elif self == Dialect.POSTGRES:
            return '"'
        elif self == Dialect.UNKNOWN:
            return '"'
        raise NotImplementedError(f"Quote string not implemented for dialect {self}")

    def quote(self, string: str) -> str:
        """Quote a string."""
        return f"{self.quote_str}{string}{self.quote_str}"


class ColumnOrLiteral(BaseModel):
    """A column reference which may instead be a literal (see ``literal``)."""

    name: str | None = None
    literal: bool = False

    def __hash__(self) -> int:
        """Hash by the ``(name, literal)`` pair.

        This makes instances usable as set members and dict keys.
        """
        identity = (self.name, self.literal)
        return hash(identity)


class TableColumn(BaseModel):
    """A single table column: its name and optional data type."""

    # Column name.
    name: str
    # Column data type as a string (presumably the SQL type name), or None if
    # unknown.
    dtype: str | None


class ForeignKey(BaseModel):
    """Foreign key linking a column in one table to a column in another."""

    # Column on the referencing side of the key. Presumably this column belongs
    # to the table whose ``fks`` list holds this key (confirm with callers).
    column: TableColumn
    # Name of the referenced (target) table.
    references_name: str
    # Column in the referenced table that ``column`` points to.
    references_column: TableColumn


class Table(BaseModel):
    """Schema description of a single table, optionally with example rows."""

    # Table name; may be None.
    name: str | None
    # Columns of the table, if known.
    columns: list[TableColumn] | None
    # Primary-key columns, if known.
    pks: list[TableColumn] | None
    # FK from this table to another column in another table
    fks: list[ForeignKey] | None
    # Example rows. NOTE(review): presumably each dict maps column name ->
    # value — confirm against the code that populates it.
    examples: list[dict] | None
    # True if this is an intermediate reference table rather than a source
    # table.
    is_reference_table: bool = False


class TextToSQLParams(BaseModel):
    """Parameters of a text-to-SQL request."""

    # The request's instruction (presumably the natural-language question to
    # turn into SQL).
    instruction: str
    # Target database name, if any.
    database: str | None
    # SQL dialect of the target database. Default to unknown
    dialect: Dialect = Dialect.UNKNOWN
    # Table schemas available to the query, if provided.
    tables: list[Table] | None


class TextToSQLModelResponse(BaseModel):
    """Response from a text-to-SQL model call.

    NOTE(review): the original docstring said "Autocomplete Responses"; confirm
    whether this class is also used for autocomplete.
    """

    # Model output. Presumably post-processed, as opposed to ``raw_output``.
    output: str
    # The prompt sent to the model: a plain string or a list of dicts
    # (presumably chat-style messages — confirm).
    final_prompt: str | list[dict]
    # Model output as returned, before any processing (assumed).
    raw_output: str
    # Usage information for the call. NOTE(review): typed ``Any`` although
    # ``Usage`` is imported from manifest.response and unused in this module —
    # confirm whether this should be ``Usage`` (changing it alters validation).
    usage: Any
    # Optional free-form metadata string.
    metadata: str | None = None