---
license: mit
datasets:
- b-mc2/sql-create-context
- gretelai/synthetic_text_to_sql
language:
- en
base_model: google-t5/t5-base
metrics:
- exact_match
model-index:
- name: juanfra218/text2sql
  results:
  - task:
      type: text-to-sql
    metrics:
    - name: exact_match
      type: exact_match
      value: 0.4326836917562724
    - name: bleu
      type: bleu
      value: 0.6687
tags:
- sql
library_name: transformers
---

# Fine-Tuned Google T5 Model for Text to SQL Translation

A fine-tuned version of Google's T5-base model (`google-t5/t5-base`), trained to translate natural-language questions into SQL queries, given the schema of the target tables as context.

## Model Details

- **Architecture**: Google T5 Base (Text-to-Text Transfer Transformer)
- **Task**: Text to SQL Translation
- **Fine-Tuning Datasets**:
  - [sql-create-context Dataset](https://huggingface.co/datasets/b-mc2/sql-create-context)
  - [Synthetic-Text-To-SQL Dataset](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)
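
Both datasets pair a natural-language question with `CREATE TABLE` context and a target query, but under different column names (`question`/`context`/`answer` in sql-create-context, `sql_prompt`/`sql_context`/`sql` in synthetic_text_to_sql, according to their dataset cards; verify before relying on them). The sketch below maps a record from either dataset to an input/target pair. It reuses the `translate English to SQL:` prefix from the usage code; that training used exactly this prefix is an assumption inferred from the inference code, and `to_example` is a hypothetical helper, not part of the released files.

```python
def to_example(record: dict) -> dict:
    """Map one dataset record to an (input_text, target_text) pair.

    Hypothetical helper: assumes the column names listed on each dataset card.
    """
    if "question" in record:  # b-mc2/sql-create-context
        prompt, schema, sql = record["question"], record["context"], record["answer"]
    else:  # gretelai/synthetic_text_to_sql
        prompt, schema, sql = record["sql_prompt"], record["sql_context"], record["sql"]
    # Same input layout as generate_sql() in the usage code: prefix, question, schema
    return {
        "input_text": "translate English to SQL: " + prompt + " " + schema,
        "target_text": sql,
    }
```

With `datasets.Dataset.map`, a function like this could normalize both sources to one layout before tokenization.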
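
## Training Parameters

The model was fine-tuned with the following `Seq2SeqTrainingArguments`:

```python
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # newer transformers releases name this `eval_strategy`
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,           # keep at most 3 checkpoints on disk
    num_train_epochs=3,
    predict_with_generate=True,   # evaluate with model.generate() instead of teacher forcing
    fp16=True,                    # mixed-precision training; requires a GPU
    push_to_hub=False,
)
```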
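
These arguments leave `gradient_accumulation_steps`, `lr_scheduler_type`, and warmup at their transformers defaults (1, linear decay, no warmup). The sketch below shows what that implies for the length of training and the learning-rate schedule, assuming a single device; the training-set size is not stated in this card, so it is a parameter here.

```python
import math


def total_training_steps(num_examples: int, batch_size: int = 8, num_epochs: int = 3) -> int:
    """Optimizer steps on one device with no gradient accumulation.

    Each epoch runs ceil(num_examples / batch_size) steps, because the
    final partial batch is kept rather than dropped.
    """
    return num_epochs * math.ceil(num_examples / batch_size)


def linear_lr(step: int, total_steps: int, base_lr: float = 2e-5) -> float:
    """Learning rate under linear decay from base_lr to 0 with no warmup."""
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps))
```

For a hypothetical training split of 100,000 rows, `total_training_steps(100_000)` gives 3 × 12,500 = 37,500 steps, and the learning rate reaches half of 2e-5 at step 18,750.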

## Usage

The snippet below loads the model and runs an interactive loop that reads a question and a table schema (for example, `CREATE TABLE` statements) and prints the generated SQL.

```python
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model
model_path = "juanfra218/text2sql"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.to(device)

# Generate an SQL query from a natural-language prompt and a table schema
def generate_sql(prompt, schema):
    input_text = "translate English to SQL: " + prompt + " " + schema
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")

    # Move input tensors to the same device as the model
    inputs = {key: value.to(device) for key, value in inputs.items()}

    max_output_length = 1024
    outputs = model.generate(**inputs, max_length=max_output_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Interactive loop: check for 'quit' before asking for a schema
print("Enter 'quit' to exit.")
while True:
    prompt = input("Insert prompt: ")
    if prompt.strip().lower() == "quit":
        break
    schema = input("Insert schema: ")

    sql_query = generate_sql(prompt, schema)
    print(f"Generated SQL query: {sql_query}")
    print()
```
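
Generated queries are not guaranteed to be valid. One inexpensive check, sketched below under the assumption that the schema is a set of semicolon-separated `CREATE TABLE` statements (as in sql-create-context), is to compile the output from `generate_sql` against that schema in an in-memory SQLite database. `is_valid_sql` is a hypothetical helper; it only checks SQLite's dialect, so queries written for another engine may be rejected even if correct there, and a malformed schema also yields `False`.

```python
import sqlite3


def is_valid_sql(schema: str, sql_query: str) -> bool:
    """Return True if sql_query compiles against schema in SQLite."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema)  # create the (empty) tables
        # EXPLAIN compiles the statement without executing it
        conn.execute("EXPLAIN " + sql_query)
        return True
    except (sqlite3.Error, sqlite3.Warning):
        # sqlite3.Warning covers "one statement at a time" on older Python versions
        return False
    finally:
        conn.close()
```

For example, `is_valid_sql(schema, generate_sql(prompt, schema))` flags outputs that reference columns missing from the schema.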

## Files

- `optimizer.pt`: State of the optimizer.
- `training_args.bin`: Training arguments and hyperparameters.
- `tokenizer.json`: Tokenizer vocabulary and settings.
- `spiece.model`: SentencePiece model file.
- `special_tokens_map.json`: Special tokens mapping.
- `tokenizer_config.json`: Tokenizer configuration settings.
- `model.safetensors`: Trained model weights.
- `generation_config.json`: Configuration for text generation.
- `config.json`: Model architecture configuration.
- `test_results.csv`: Results on the test set, with the columns `prompt`, `context`, `true_answer`, `predicted_answer`, and `exact_match`.
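
The exact-match score in the metadata (about 0.433) can presumably be recomputed from `test_results.csv`. The card does not state how queries were normalized before comparison, so the sketch below assumes a strict string comparison after trimming surrounding whitespace and may not reproduce the reported figure exactly. Only the column names come from this card; the helper names are hypothetical.

```python
import csv
from typing import Iterable, Mapping


def is_exact_match(true_answer: str, predicted_answer: str) -> bool:
    """Strict comparison after trimming whitespace (assumed normalization)."""
    return true_answer.strip() == predicted_answer.strip()


def exact_match_rate(rows: Iterable[Mapping[str, str]]) -> float:
    """Fraction of rows whose predicted_answer equals true_answer."""
    total = 0
    hits = 0
    for row in rows:
        total += 1
        if is_exact_match(row["true_answer"], row["predicted_answer"]):
            hits += 1
    return hits / total if total else 0.0


def exact_match_from_csv(path: str) -> float:
    """Recompute the exact-match rate from a CSV with the columns listed above."""
    with open(path, newline="", encoding="utf-8") as f:
        return exact_match_rate(csv.DictReader(f))
```

Usage: `print(exact_match_from_csv("test_results.csv"))`.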