import json
import os

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


def load_table_schemas(tables_file):
    """
    Load table schemas from a tables.jsonl file.

    Args:
        tables_file: Path to the tables.jsonl file.

    Returns:
        A dictionary mapping table IDs to their column names.
    """
    table_schemas = {}
    with open(tables_file, "r") as f:
        for line in f:
            table_data = json.loads(line)
            table_schemas[table_data["id"]] = table_data["header"]
    return table_schemas


# Step 1: Load and Preprocess WikiSQL Data
def load_wikisql(data_dir):
    """
    Load WikiSQL data and prepare it for training.

    Args:
        data_dir: Path to the WikiSQL dataset directory.

    Returns:
        A (train, dev) pair of example lists, each item holding "input" and "target" text.
    """
    def parse_file(file_path):
        with open(file_path, "r") as f:
            return [json.loads(line) for line in f]

    train_data = parse_file(os.path.join(data_dir, "train.jsonl"))
    dev_data = parse_file(os.path.join(data_dir, "dev.jsonl"))

    # Column headers are needed to turn WikiSQL's index-based SQL dicts into query text.
    train_schemas = load_table_schemas(os.path.join(data_dir, "train.tables.jsonl"))
    dev_schemas = load_table_schemas(os.path.join(data_dir, "dev.tables.jsonl"))

    def format_data(data, split):
        formatted = []
        schemas = train_schemas if split == "train" else dev_schemas
        for item in data:
            table_columns = schemas[item["table_id"]]
            question = item["question"]
            sql_query = sql_to_text(item["sql"], table_columns)
            formatted.append({"input": f"Question: {question}", "target": sql_query})
        return formatted

    return format_data(train_data, "train"), format_data(dev_data, "dev")


def sql_to_text(sql, table_columns):
    """
    Convert a SQL dictionary from WikiSQL to its text representation.

    Args:
        sql: SQL dictionary from WikiSQL (e.g., {"sel": 5, "conds": [[3, 0, "value"]], "agg": 0}).
        table_columns: List of column names for the corresponding table.

    Returns:
        SQL query as a string.
    """
    # WikiSQL aggregation functions and condition operators, indexed by the dataset's integer codes
    agg_functions = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
    operators = ["=", ">", "<"]

    # Build the SELECT clause from the selected column and optional aggregation
    sel_column = table_columns[sql["sel"]]
    agg_func = agg_functions[sql["agg"]]
    select_clause = f"SELECT {agg_func}({sel_column})" if agg_func else f"SELECT {sel_column}"

    # Build the WHERE clause from the (column index, operator index, value) conditions
    if sql["conds"]:
        conditions = []
        for col_idx, operator, value in sql["conds"]:
            col_name = table_columns[col_idx]
            conditions.append(f"{col_name} {operators[operator]} '{value}'")
        where_clause = " WHERE " + " AND ".join(conditions)
    else:
        where_clause = ""

    # Combine clauses into a full query
    return select_clause + where_clause
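
# Quick sanity check for sql_to_text. The schema and condition below are made up for
# illustration (they are not taken from WikiSQL), but the dict layout
# ({"sel": ..., "agg": ..., "conds": [[col, op, value]]}) matches the dataset format above.
_example_columns = ["Player", "Country", "Points"]
_example_sql = {"sel": 0, "agg": 3, "conds": [[1, 0, "Spain"]]}
assert sql_to_text(_example_sql, _example_columns) == "SELECT COUNT(Player) WHERE Country = 'Spain'"
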
""" inputs = [item["input"] for item in data] targets = [item["target"] for item in data] tokenized = tokenizer( inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", ) labels = tokenizer( targets, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", ) tokenized["labels"] = labels["input_ids"] return tokenized # Step 3: Load Model and Tokenizer model_name = "t5-small" # Use "t5-small", "t5-base", or "t5-large" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSeq2SeqLM.from_pretrained(model_name) # Step 4: Prepare Training and Validation Data data_dir = "data" # Path to the WikiSQL dataset train_data, dev_data = load_wikisql(data_dir) # Tokenize Data train_dataset = tokenize_data(train_data, tokenizer) dev_dataset = tokenize_data(dev_data, tokenizer) # # Convert to Hugging Face Dataset format train_dataset = Dataset.from_dict(train_dataset) dev_dataset = Dataset.from_dict(dev_dataset) # # # Step 5: Define Training Arguments # training_args = Seq2SeqTrainingArguments( # output_dir="./t5_sql_finetuned", # evaluation_strategy="steps", # save_steps=1000, # eval_steps=100, # logging_steps=100, # per_device_train_batch_size=16, # per_device_eval_batch_size=16, # num_train_epochs=3, # save_total_limit=2, # learning_rate=5e-5, # predict_with_generate=True, # fp16=torch.cuda.is_available(), # Enable mixed precision for faster training # logging_dir="./logs", # ) # # # Step 6: Define Trainer # trainer = Seq2SeqTrainer( # model=model, # args=training_args, # train_dataset=train_dataset, # eval_dataset=dev_dataset, # tokenizer=tokenizer, # ) # # # Step 7: Train the Model # trainer.train() # # # Step 8: Save the Model # trainer.save_model("./t5_sql_finetuned") # tokenizer.save_pretrained("./t5_sql_finetuned") # # Step 9: Test the Model test_question = "Find all orders with product_id greater than 5." input_text = f"Question: {test_question}" inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True) outputs = model.generate(**inputs, max_length=128) generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True) print("Generated SQL:", generated_sql)