from typing import Optional
import re

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the fine-tuned model and tokenizer from the local checkpoint directory.
# NOTE: this runs at import time; the checkpoint must exist at this path.
model = T5ForConditionalGeneration.from_pretrained("./t5_sql_finetuned")
tokenizer = T5Tokenizer.from_pretrained("./t5_sql_finetuned")

# Substrings that route a question to the schema handler instead of the model.
# Matching is substring-based, so "relations" also catches "relationships".
SCHEMA_KEYWORDS = ("columns", "tables", "structure", "schema", "relations", "fields")

# Fallback patterns used when no known schema table appears in the question:
# "... for/in/from/of [the] <name>" or "<name> table". The word boundaries stop
# "in" from matching inside longer words such as "contain".
_TABLE_NAME_PATTERN = re.compile(
    r"\b(?:for|in|from|of)\s+(?:the\s+)?(\w+)|\b(\w+)\s+table\b",
    re.IGNORECASE,
)

# Words the fallback patterns can capture that are never table names.
_NON_TABLE_WORDS = frozenset(
    {"the", "a", "an", "this", "that", "each", "every", "all", "which", "what", "my", "our", "your"}
)

_UNCLEAR_SCHEMA_QUESTION = "Sorry, I couldn't understand your schema question. Could you rephrase?"


def is_schema_question(question: str) -> bool:
    """Return True if *question* mentions any of ``SCHEMA_KEYWORDS`` (case-insensitive)."""
    lowered = question.lower()
    return any(keyword in lowered for keyword in SCHEMA_KEYWORDS)


def extract_table_name(question: str, schema: Optional[dict] = None) -> Optional[str]:
    """Return the table name referenced by *question*, or None if none is found.

    When *schema* is given, the first word of the question that names one of its
    tables is returned, as the schema's own key. A singular form also counts:
    "user" matches table "users". If no schema table is mentioned, or no schema
    is given, a word-bounded regex looks for "for/in/from/of [the] <name>" or
    "<name> table". Filler words such as "the" are skipped.

    Args:
        question: Natural-language question.
        schema: Optional mapping whose keys are table names.

    Returns:
        The detected table name, or None.
    """
    if schema:
        known = {str(name).lower(): name for name in schema}
        for word in re.findall(r"\w+", question.lower()):
            if word in known:
                return known[word]
            # Accept singular mentions such as "product" for table "products".
            if word + "s" in known:
                return known[word + "s"]

    for match in _TABLE_NAME_PATTERN.finditer(question):
        candidate = match.group(1) or match.group(2)
        if candidate.lower() not in _NON_TABLE_WORDS:
            return candidate
    return None


def generate_sql(question: str, schema: dict, model, tokenizer, device, max_length: int = 128) -> str:
    """Generate an SQL query for *question* with the seq2seq *model*.

    Args:
        question: Natural-language data question, passed to the tokenizer verbatim.
            If the model was fine-tuned with a task prefix, the question presumably
            needs that prefix; TODO confirm against the training data.
        schema: Accepted for interface compatibility. It is not currently fed to the model.
        model: Model exposing ``generate``. It is moved to *device* in place.
        tokenizer: Tokenizer matching *model*.
        device: torch device that both the model and the inputs are placed on.
        max_length: Maximum length of the generated sequence (default 128).

    Returns:
        The decoded SQL string, with special tokens removed.
    """
    # Model and inputs must share a device, or generate() fails on CUDA.
    model.to(device)
    inputs = tokenizer(question, return_tensors="pt").to(device)
    with torch.no_grad():
        # Pass the full encoding so the attention mask reaches generate().
        generated_ids = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


def handle_schema_question(question: str, schema: dict):
    """Answer a question about the schema's tables, columns or relations.

    Args:
        question: Natural-language schema question, e.g.
            "What columns does the products table have?".
        schema: Mapping of table name -> {"columns": ..., "relations": ...}.

    Returns:
        For a column question, the table's ``"columns"`` value. For a relation
        question, its ``"relations"`` value. For a "tables" question, the list of
        table names. Otherwise a human-readable message: the table is unknown,
        has no such entries, or the question was not understood.
    """
    question = question.lower()

    if "columns" in question or "fields" in question:
        attribute = "columns"
    elif "relations" in question or "relationships" in question:
        attribute = "relations"
    elif "tables" in question:
        return list(schema.keys())
    else:
        return _UNCLEAR_SCHEMA_QUESTION

    table_name = extract_table_name(question, schema)
    if table_name is None:
        return _UNCLEAR_SCHEMA_QUESTION
    if table_name not in schema:
        return f"Table '{table_name}' not found in the schema."

    value = schema[table_name].get(attribute)
    # Report missing/empty entries explicitly instead of surfacing None or [].
    if not value:
        return f"Table '{table_name}' has no {attribute}."
    return value


# Example schema. "relations" is always a list of "col -> table.col" strings
# (empty when the table references nothing).
custom_schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        "relations": ["category_id -> categories.id"],
    },
    "categories": {
        "columns": ["id", "category_name"],
        "relations": [],
    },
    "orders": {
        "columns": ["order_id", "user_id", "product_id", "order_date"],
        "relations": ["product_id -> products.product_id", "user_id -> users.user_id"],
    },
    "users": {
        "columns": ["user_id", "first_name", "last_name", "email", "phone_number", "address"],
        "relations": [],
    },
}


def answer_question(question: str, schema: dict, model, tokenizer, device) -> str:
    """Route *question* to the schema handler or the SQL generator.

    Returns a string prefixed with "Schema Information: " or "Generated SQL Query: ".
    """
    if is_schema_question(question):
        response = handle_schema_question(question, schema)
        return f"Schema Information: {response}"
    sql_query = generate_sql(question, schema, model, tokenizer, device)
    return f"Generated SQL Query: {sql_query}"


# Example input questions
question_1 = "What columns does the products table have?"
question_2 = "What is the price of the product with product_id 123?"

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model must live on the same device as the input tensors, otherwise
# generate() raises a device-mismatch error when CUDA is available.
# nn.Module.to() is in place, and a no-op if the model is already there.
model.to(device)

# Schema question about the products table.
response_1 = answer_question(question_1, custom_schema, model, tokenizer, device)
print(response_1)

# Data question: expected to be routed to the model for SQL generation.
response_2 = answer_question(question_2, custom_schema, model, tokenizer, device)
print(response_2)