Manoj Kumar committed on
Commit
7c39f2c
·
1 Parent(s): 2621d33

updated code

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. db.py +55 -0
  3. t5.py +21 -0
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.11.0
8
- app_file: database.py
9
  pinned: false
10
  python: 3.9
11
  ---
 
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.11.0
8
+ app_file: t5.py
9
  pinned: false
10
  python: 3.9
11
  ---
db.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+
3
# Illustrative database schema. Each table name maps to its column names and
# an optional relationship description ("local_column -> table.column", or
# None when the table declares no relationship).
schema = {
    "products": {
        "columns": [
            "product_id",
            "name",
            "price",
            "category_id",
        ],
        "relations": "category_id -> categories.id",
    },
    "categories": {
        "columns": [
            "id",
            "category_name",
        ],
        "relations": None,
    },
    "orders": {
        "columns": [
            "order_id",
            "customer_name",
            "product_id",
            "order_date",
        ],
        "relations": "product_id -> products.product_id",
    },
}
18
+
19
+ # Step 1: Generate context dynamically from schema
20
def generate_context(schema: dict) -> str:
    """Render a schema mapping as plain-English sentences, one per line.

    Args:
        schema: Mapping of table name to a table spec. Each spec must provide
            ``"columns"`` (an iterable of column-name strings) and may provide
            ``"relations"`` (a string describing a relationship). A missing,
            ``None`` or empty ``"relations"`` entry produces no relationship
            sentence.

    Returns:
        The generated sentences joined with newlines, or an empty string when
        the schema has no tables.

    Raises:
        KeyError: If a table spec has no ``"columns"`` entry.
    """
    context_lines = []
    for table, details in schema.items():
        # One sentence listing the table's columns
        columns = ", ".join(details["columns"])
        context_lines.append(f"The {table} table has the following columns: {columns}.")

        # Relationships are optional: tolerate specs that omit the key entirely
        relations = details.get("relations")
        if relations:
            context_lines.append(
                f"The {table} table has the following relationship: {relations}."
            )

    return "\n".join(context_lines)
32
+
33
# Render the example schema once; the text is embedded in the prompt below.
schema_context = generate_context(schema)

# Step 2: Load the tokenizer and seq2seq weights for the SQL-generation checkpoint
model_name = "mrm8488/t5-base-finetuned-wikiSQL"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Step 3: The natural-language question to translate
user_query = "List all orders where the product price is greater than 50."

# Prompt layout: instruction line, schema description, then the question.
# NOTE(review): t5.py prompts this same checkpoint with
# "translate English to SQL: ..." -- confirm this longer prompt format
# produces usable SQL.
input_text = (
    "Convert the following question into an SQL query:\n"
    f"Schema:\n{schema_context}\n\n"
    f"Question:\n{user_query}"
)
inputs = tokenizer.encode(input_text, return_tensors="pt")

# Step 4: Beam-search generation (4 beams, max_length=128), then decode the
# first returned sequence without special tokens
outputs = model.generate(
    inputs,
    max_length=128,
    num_beams=4,
    early_stopping=True,
)
generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Step 5: Show the question next to the generated SQL
print("User Query:", user_query)
print("Generated SQL Query:", generated_sql)
t5.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer
2
+
3
# Load the tokenizer and weights for the WikiSQL checkpoint. T5 is an
# encoder-decoder model, so it is loaded through AutoModelForSeq2SeqLM;
# AutoModelWithLMHead is deprecated in transformers in favour of the
# task-specific auto classes.
tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
5
+
6
def get_sql(query, max_length=128):
    """Translate an English question into a SQL string with the loaded model.

    Args:
        query: Natural-language question to translate.
        max_length: Maximum length of the generated token sequence, forwarded
            to ``model.generate``. Set explicitly so longer queries are not
            cut off by the library's short default.

    Returns:
        The decoded SQL text with special tokens (padding / end-of-sequence
        markers) removed.
    """
    # The tokenizer appends the end-of-sequence token itself, so it is not
    # written into the prompt by hand (doing both can duplicate it).
    input_text = f"translate English to SQL: {query}"
    features = tokenizer([input_text], return_tensors="pt")

    output = model.generate(
        input_ids=features["input_ids"],
        attention_mask=features["attention_mask"],
        max_length=max_length,
    )

    # Strip special tokens so callers receive plain SQL text
    return tokenizer.decode(output[0], skip_special_tokens=True)
14
+
15
# Demo: translate one sample question and print the model's answer.
# NOTE(review): the README front matter sets app_file to t5.py with
# sdk: gradio, but this script only prints -- confirm whether a Gradio
# interface is intended here.
query = "How many models were finetuned using BERT as base model?"
res = get_sql(query)
print(res)

# Expected output, as noted by the original author:
# 'SELECT COUNT Model fine tuned FROM table WHERE Base model = BERT'