File size: 2,798 Bytes
b247dc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Training data prep utils."""
import json
import re
from collections import defaultdict
from schema import ForeignKey, Table, TableColumn


def read_tables_json(
    schema_file: str,
    lowercase: bool = False,
) -> dict[str, dict[str, Table]]:
    """Read tables json."""
    data = json.load(open(schema_file))
    db_to_tables = {}
    for db in data:
        db_name = db["db_id"]
        table_names = db["table_names_original"]
        db["column_names_original"] = [
            [x[0], x[1]] for x in db["column_names_original"]
        ]
        db["column_types"] = db["column_types"]
        if lowercase:
            table_names = [tn.lower() for tn in table_names]
        pks = db["primary_keys"]
        fks = db["foreign_keys"]
        tables = defaultdict(list)
        tables_pks = defaultdict(list)
        tables_fks = defaultdict(list)
        for idx, ((ti, col_name), col_type) in enumerate(
            zip(db["column_names_original"], db["column_types"])
        ):
            if ti == -1:
                continue
            if lowercase:
                col_name = col_name.lower()
                col_type = col_type.lower()
            if idx in pks:
                tables_pks[table_names[ti]].append(
                    TableColumn(name=col_name, dtype=col_type)
                )
            for fk in fks:
                if idx == fk[0]:
                    other_column = db["column_names_original"][fk[1]]
                    other_column_type = db["column_types"][fk[1]]
                    other_table = table_names[other_column[0]]
                    tables_fks[table_names[ti]].append(
                        ForeignKey(
                            column=TableColumn(name=col_name, dtype=col_type),
                            references_name=other_table,
                            references_column=TableColumn(
                                name=other_column[1], dtype=other_column_type
                            ),
                        )
                    )
            tables[table_names[ti]].append(TableColumn(name=col_name, dtype=col_type))
        db_to_tables[db_name] = {
            table_name: Table(
                name=table_name,
                columns=tables[table_name],
                pks=tables_pks[table_name],
                fks=tables_fks[table_name],
                examples=None,
            )
            for table_name in tables
        }
    return db_to_tables


def clean_str(target: str) -> str:
    """Clean string for question."""
    if not target:
        return target

    target = re.sub(r"[^\x00-\x7f]", r" ", target)
    line = re.sub(r"''", r" ", target)
    line = re.sub(r"``", r" ", line)
    line = re.sub(r"\"", r"'", line)
    line = re.sub(r"[\t ]+", " ", line)
    return line.strip()