Fine-Tuned Google T5 Model for Text to SQL Translation

A version of Google's T5 model (t5-base) fine-tuned to translate natural language queries into SQL statements.

Model Details

Training Parameters

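from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # called eval_strategy in newer transformers releases
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,  # keep at most 3 checkpoints in output_dir
    num_train_epochs=3,
    predict_with_generate=True,  # use generate() when evaluating
    fp16=True,  # mixed-precision training; typically requires a CUDA GPU
    push_to_hub=False,
)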
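These arguments fix the per-device batch size (8), the number of epochs (3) and the peak learning rate (2e-5), but the card does not state the training-set size. The sketch below shows how these values turn into optimizer steps. It assumes a single device, no gradient accumulation, and the Trainer's default linear learning-rate decay with no warmup, since neither setting is overridden above. The 10,000-example figure is hypothetical and only illustrates the arithmetic.

```python
import math

# Values copied from the Seq2SeqTrainingArguments above.
LEARNING_RATE = 2e-5
PER_DEVICE_BATCH_SIZE = 8
NUM_EPOCHS = 3


def total_training_steps(num_examples: int) -> int:
    """Optimizer steps for the full run on one device, no gradient accumulation.

    The last (possibly smaller) batch of each epoch still counts as a step.
    """
    steps_per_epoch = math.ceil(num_examples / PER_DEVICE_BATCH_SIZE)
    return steps_per_epoch * NUM_EPOCHS


def linear_lr(step: int, total_steps: int) -> float:
    """Learning rate at `step` for a linear decay from the peak to zero, no warmup."""
    return LEARNING_RATE * max(0.0, (total_steps - step) / total_steps)


# Hypothetical training-set size, used only to illustrate the arithmetic.
total = total_training_steps(10_000)
print(total, linear_lr(0, total), linear_lr(total // 2, total))
```

With 10,000 examples this gives 1250 steps per epoch and 3750 steps in total. The learning rate starts at 2e-5 and falls to half that value at step 1875.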

Usage

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model
model_path = 'juanfra218/text2sql'
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.to(device)  

# Function to generate SQL queries
def generate_sql(prompt, schema):
    input_text = "translate English to SQL: " + prompt + " " + schema
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")

    inputs = {key: value.to(device) for key, value in inputs.items()}

    max_output_length = 1024
    outputs = model.generate(**inputs, max_length=max_output_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Interactive loop
print("Enter 'quit' to exit.")
while True:
    prompt = input("Insert prompt: ")
    # Check for 'quit' before asking for a schema so the loop exits immediately
    if prompt.strip().lower() == 'quit':
        break
    schema = input("Insert schema: ")

    sql_query = generate_sql(prompt, schema)
    print(f"Generated SQL query: {sql_query}")
    print()
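The model returns SQL as plain text, and nothing guarantees that the query parses or refers to real tables and columns. One way to catch malformed output is to compile the generated query against an empty in-memory SQLite database. This only works if the schema you pass in consists of SQLite-compatible `CREATE TABLE` statements, which is an assumption because the card does not specify the schema format. `check_sql` is a hypothetical helper and is not part of this model's repository.

```python
import sqlite3


def check_sql(sql: str, schema: str) -> tuple[bool, str]:
    """Return (True, "") if `sql` compiles against `schema`, else (False, error).

    Assumes `schema` is SQLite-compatible CREATE TABLE DDL. An invalid schema
    is also reported as a failure.
    """
    conn = sqlite3.connect(":memory:")
    try:
        # Create the (empty) tables described by the schema.
        conn.executescript(schema)
        # EXPLAIN makes SQLite prepare the statement without running the query,
        # which surfaces syntax errors and unknown tables or columns.
        conn.execute("EXPLAIN " + sql)
        return True, ""
    except sqlite3.Error as exc:
        return False, str(exc)
    finally:
        conn.close()


schema = "CREATE TABLE head (name VARCHAR, age INTEGER);"
print(check_sql("SELECT name FROM head WHERE age > 56", schema))
print(check_sql("SELECT salary FROM head", schema))
```

In the interactive loop, you could call `check_sql(sql_query, schema)` right after `generate_sql` and print the error message whenever the first element is False.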

Files

  • optimizer.pt: State of the optimizer.
  • training_args.bin: Training arguments and hyperparameters.
  • tokenizer.json: Tokenizer vocabulary and settings.
  • spiece.model: SentencePiece model file.
  • special_tokens_map.json: Special tokens mapping.
  • tokenizer_config.json: Tokenizer configuration settings.
  • model.safetensors: Trained model weights.
  • generation_config.json: Configuration for text generation.
  • config.json: Model architecture configuration.
  • test_results.csv: Results on the test set, with the columns prompt, context, true_answer, predicted_answer and exact_match.
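The columns of test_results.csv are enough to recompute the exact-match accuracy on the test set. The sketch below reads the file with the standard `csv` module and averages the `exact_match` column. It assumes the column holds boolean-like values (`True`/`False` or `1`/`0`), which the card does not state.

```python
import csv


def exact_match_accuracy(path: str) -> float:
    """Fraction of rows in a test_results.csv-style file whose exact_match flag is truthy.

    Assumes exact_match is stored as True/False or 1/0 text.
    """
    truthy = {"true", "1", "1.0"}
    with open(path, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    if not rows:
        raise ValueError(f"no rows in {path}")
    hits = sum(row["exact_match"].strip().lower() in truthy for row in rows)
    return hits / len(rows)
```

Usage: `exact_match_accuracy("test_results.csv")`, with the path adjusted to wherever the file was downloaded.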

Model size: 223M parameters (Safetensors weights, F32 tensors)
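As a rough guide to memory requirements, the parameter count converts to weight storage as parameters × bytes per parameter. The snippet below treats 223M as exactly 223,000,000, although the card's figure is rounded. It counts only the weights and excludes activations, the tokenizer and the optimizer state stored in optimizer.pt.

```python
# Approximate weight memory for a model with the card's parameter count.
NUM_PARAMS = 223_000_000  # rounded figure from the model card
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_bytes(dtype: str, num_params: int = NUM_PARAMS) -> int:
    """Bytes needed to hold the weights alone in the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype]


for dtype in BYTES_PER_PARAM:
    size = weight_bytes(dtype)
    print(f"{dtype:>8}: {size / 1e6:,.0f} MB ({size / 2**20:,.0f} MiB)")
```

For the F32 weights this comes to 892,000,000 bytes, or roughly 851 MiB. Half-precision types halve that figure.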

Base model: google-t5/t5-base
