"""Data loaders."""
import json
import re
import string
from abc import ABC, abstractmethod

from rich.console import Console
from data_utils import read_tables_json
from schema import Table

RE_COLUMN = re.compile(r"^select (.+?) from")
RE_CONDS = re.compile(r"where (.+?)$")
RE_COND = re.compile(r"^(.+?)\s*([=><])\s*(.+?)$")
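

# Illustrative sketch (not part of the original module): one way the three
# patterns above might be combined to split a simple, lowercased query of the
# form "select <col> from <table> where <cond> and <cond>". The helper name
# `_sketch_parse_query` and the " and " separator are assumptions.
def _sketch_parse_query(sql: str) -> tuple[str, list[tuple[str, str, str]]]:
    """Return the selected column and a list of (column, operator, value)."""
    col_match = RE_COLUMN.search(sql)
    # Fall back to an empty string when no "select ... from" prefix is found.
    column = col_match.group(1) if col_match else ""
    conds = []
    conds_match = RE_CONDS.search(sql)
    if conds_match:
        for part in conds_match.group(1).split(" and "):
            cond = RE_COND.match(part.strip())
            if cond:
                conds.append(cond.groups())
    return column, conds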

translator = str.maketrans(
    string.punctuation, " " * len(string.punctuation)
)  # map punctuation to space

console = Console(soft_wrap=True)


def standardize_column(col: str) -> str:
    """Standardize the column name to SQL compatible."""
    col_name = col.replace("#", "num").replace("%", "perc")
    col_name = col_name.strip().lower().translate(translator)
    col_name = re.sub("[^0-9a-z ]", " ", col_name).strip()
    col_name = re.sub(" +", "_", col_name)
    if not col_name:
        # Warn when nothing usable is left after cleaning.
        console.print(f"Column {col!r} standardized to an empty name")
    return col_name


def clean_col(col: str) -> str:
    """Remove table name and standardize column name."""
    if "." in col and not col.endswith("."):
        col = col.split(".")[-1]
    return standardize_column(col)
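

# Illustrative sketch (not part of the original module): sample inputs showing
# what the two helpers above return. The name `_sketch_column_examples` and
# the sample column names are made up for illustration.
def _sketch_column_examples() -> dict[str, str]:
    """Map a few raw column names to their cleaned forms."""
    samples = ["Win %", "# of Games", "t1.Player Name", "No."]
    # "Win %" -> "win_perc", "# of Games" -> "num_of_games",
    # "t1.Player Name" -> "player_name", "No." -> "no"
    return {raw: clean_col(raw) for raw in samples}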


class Loader(ABC):
    """Loader abstract class."""

    @classmethod
    @abstractmethod
    def load_data(cls, path: str) -> list[dict]:
        """Load data from path."""

    @classmethod
    @abstractmethod
    def load_table_metadata(cls, path: str) -> dict[str, dict[str, Table]]:
        """Extract table metadata from table-metadata-path."""

    @classmethod
    def format_output(cls, prediction: dict) -> dict:
        """Format a prediction for output; the default returns it unchanged."""
        return prediction


class DefaultLoader(Loader):
    """Spider loader and writer."""

    @classmethod
    def load_data(cls, path: str) -> list[dict]:
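        """Load data from a JSON file, falling back to JSON Lines."""
        with open(path) as f:
            text = f.read()
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # Fall back to JSON Lines: one JSON document per non-empty line.
            lines = text.split("\n")
            data = [json.loads(line) for line in lines if line.strip()]
        return data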

    @classmethod
    def load_table_metadata(cls, path: str) -> dict[str, dict[str, Table]]:
        """Extract table metadata from table-metadata-path."""
        # Load the tables, lowercasing names via `lowercase=True`.
        return read_tables_json(path, lowercase=True)