Spaces:
Running
Running
File size: 2,798 Bytes
b247dc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
"""Training data prep utils."""
import json
import re
from collections import defaultdict
from schema import ForeignKey, Table, TableColumn
def read_tables_json(
    schema_file: str,
    lowercase: bool = False,
) -> dict[str, dict[str, Table]]:
    """Read a tables-json schema file into per-database ``Table`` objects.

    Args:
        schema_file: Path to a JSON file containing a list of database entries.
            Each entry is read for the keys ``db_id``, ``table_names_original``,
            ``column_names_original`` (``[table_index, column_name]`` pairs),
            ``column_types``, ``primary_keys`` (column indices) and
            ``foreign_keys`` (``[source_column_index, referenced_column_index]``
            pairs).
        lowercase: If True, lowercase every table name, column name and column
            type -- including the referenced column of each foreign key.

    Returns:
        Mapping ``db_id -> {table_name -> Table}``. Only tables that own at
        least one column are included; ``Table.examples`` is always ``None``.
        If a ``db_id`` repeats, the later entry wins.
    """
    # Close the handle deterministically; JSON text is UTF-8 regardless of
    # the platform's locale encoding.
    with open(schema_file, encoding="utf-8") as f:
        data = json.load(f)
    return {db["db_id"]: _read_db_tables(db, lowercase) for db in data}


def _flatten_key_indices(keys: list) -> set[int]:
    """Return the set of column indices named in ``keys``.

    Plain ints are taken as-is; a nested list (presumably a composite key)
    contributes each of its indices.
    """
    flat: set[int] = set()
    for key in keys:
        if isinstance(key, list):
            flat.update(key)
        else:
            flat.add(key)
    return flat


def _read_db_tables(db: dict, lowercase: bool) -> dict[str, Table]:
    """Build ``{table_name: Table}`` for a single database entry."""

    def norm(text: str) -> str:
        return text.lower() if lowercase else text

    table_names = [norm(tn) for tn in db["table_names_original"]]
    col_entries = db["column_names_original"]
    col_types = db["column_types"]

    def make_column(idx: int) -> TableColumn:
        # Entry idx is [table_index, column_name]; its type is col_types[idx].
        # Normalizing here applies `lowercase` uniformly, including to the
        # referenced side of foreign keys.
        return TableColumn(
            name=norm(col_entries[idx][1]), dtype=norm(col_types[idx])
        )

    pk_indices = _flatten_key_indices(db["primary_keys"])
    # Index foreign keys by source column once instead of rescanning the whole
    # list for every column; per-column target order follows file order.
    fk_targets: dict[int, list[int]] = defaultdict(list)
    for fk in db["foreign_keys"]:
        fk_targets[fk[0]].append(fk[1])

    tables = defaultdict(list)
    tables_pks = defaultdict(list)
    tables_fks = defaultdict(list)
    for idx, (entry, _) in enumerate(zip(col_entries, col_types)):
        ti = entry[0]
        if ti == -1:
            # Columns with table index -1 are not attached to any table.
            continue
        table_name = table_names[ti]
        # Fresh TableColumn per list so the pk/fk/column lists never share
        # mutable instances.
        if idx in pk_indices:
            tables_pks[table_name].append(make_column(idx))
        for dst in fk_targets.get(idx, ()):
            tables_fks[table_name].append(
                ForeignKey(
                    column=make_column(idx),
                    references_name=table_names[col_entries[dst][0]],
                    references_column=make_column(dst),
                )
            )
        tables[table_name].append(make_column(idx))

    return {
        name: Table(
            name=name,
            columns=columns,
            pks=tables_pks[name],
            fks=tables_fks[name],
            examples=None,
        )
        for name, columns in tables.items()
    }
def clean_str(target: str) -> str:
    """Normalize a question string to ASCII with unified quoting and spacing.

    Steps, in order: every non-ASCII character becomes a space; each ``''``
    and each double backtick becomes a space; ``"`` becomes ``'``; runs of
    tabs/spaces collapse to one space; surrounding whitespace is stripped.
    Newlines inside the text are kept. Falsy input is returned unchanged.
    """
    if not target:
        return target
    ascii_only = "".join(ch if ord(ch) <= 0x7F else " " for ch in target)
    # Order matters: quote pairs are removed before " is rewritten to ',
    # so newly formed '' sequences are intentionally left alone.
    requoted = (
        ascii_only.replace("''", " ")
        .replace("``", " ")
        .replace('"', "'")
    )
    squeezed = re.sub(r"[\t ]+", " ", requoted)
    return squeezed.strip()
|