In [1]:
import json
import sqlparse
import pickle as pkl
# Names of the text-to-SQL datasets to process; each name is interpolated into
# the 'text2sql-data/data/%s.json' path when the original data is loaded.
dataset_names = ['academic', 'atis', 'advising', 'geography', 'imdb', 'restaurants', 'scholar', 'yelp']

# These datasets are small, so every datapoint is used. For any dataset that is
# NOT listed here, only datapoints whose 'query-split' is 'test' are kept.
new_split_defined = {'restaurants', 'academic', 'imdb', 'yelp'}

In [2]:
# loading the original datasets from the paper:
# Improving Text-to-SQL Evaluation Methodology

# a dataset is a list of dictionaries (one dictionary per datapoint)
# in the original data, each datapoint may contain several natural language sentences and several SQL queries
orig_datasets = []

# The yelp table has a "u" in the neighbourhood spelling. These two constants do
# not change between iterations, so they are defined once, outside the loop.
n_before, n_after = 'NEIGHBORHOOD', 'NEIGHBOURHOOD'

for dataset_name in dataset_names:
    # Open the file with a context manager so the handle is closed as soon as
    # the JSON has been parsed, instead of being left for the garbage collector.
    with open('text2sql-data/data/%s.json' % dataset_name) as dataset_file:
        orig_dataset = json.load(dataset_file)

    for idx, d in enumerate(orig_dataset):
        # Remember where each datapoint came from so that problems found later
        # can be traced back to (dataset name, index in the original file).
        d['orig_id'] = (dataset_name, idx)

        # fixing annotations here

        # change "company_name" to producer name, otherwise there is no variable to replace
        if dataset_name == 'imdb' and idx == 27:
            d['sql'][0] = 'SELECT MOVIEalias0.TITLE FROM COMPANY AS COMPANYalias0 , COPYRIGHT AS COPYRIGHTalias0 , MOVIE AS MOVIEalias0 WHERE COMPANYalias0.NAME = "producer_name0" AND COPYRIGHTalias0.CID = COMPANYalias0.ID AND MOVIEalias0.MID = COPYRIGHTalias0.MSID AND MOVIEalias0.RELEASE_YEAR > movie_release_year0 ;'

        # removing the extra space surrounding the variable actor_name0
        if dataset_name == 'imdb' and idx == 78:
            d['sql'][0] = 'SELECT MAX( DERIVED_TABLEalias0.DERIVED_FIELDalias0 ) FROM ( SELECT COUNT( DISTINCT ( MOVIEalias0.TITLE ) ) AS DERIVED_FIELDalias0 FROM ACTOR AS ACTORalias0 , CAST AS CASTalias0 , MOVIE AS MOVIEalias0 WHERE ACTORalias0.NAME = "actor_name0" AND CASTalias0.AID = ACTORalias0.AID AND MOVIEalias0.MID = CASTalias0.MSID GROUP BY MOVIEalias0.RELEASE_YEAR ) AS DERIVED_TABLEalias0 ;'

        # there was a scoping error; changed AUTHORalias1 to AUTHORalias0, PUBLICATIONalias1 to PUBLICATIONalias0
        if dataset_name == 'academic' and idx == 182:
            d['sql'][0] = 'SELECT DERIVED_FIELDalias0 FROM ( SELECT AUTHORalias0.NAME AS DERIVED_FIELDalias0 , COUNT( DISTINCT ( PUBLICATIONalias0.TITLE ) ) AS DERIVED_FIELDalias1 FROM AUTHOR AS AUTHORalias0 , CONFERENCE AS CONFERENCEalias0 , PUBLICATION AS PUBLICATIONalias0 , WRITES AS WRITESalias0 WHERE CONFERENCEalias0.NAME = "conference_name0" AND PUBLICATIONalias0.CID = CONFERENCEalias0.CID AND WRITESalias0.AID = AUTHORalias0.AID AND WRITESalias0.PID = PUBLICATIONalias0.PID GROUP BY AUTHORalias0.NAME ) AS DERIVED_TABLEalias0 , ( SELECT AUTHORalias1.NAME AS DERIVED_FIELDalias2 , COUNT( DISTINCT ( PUBLICATIONalias1.TITLE ) ) AS DERIVED_FIELDalias3 FROM AUTHOR AS AUTHORalias1 , CONFERENCE AS CONFERENCEalias1 , PUBLICATION AS PUBLICATIONalias1 , WRITES AS WRITESalias1 WHERE CONFERENCEalias1.NAME = "conference_name1" AND PUBLICATIONalias1.CID = CONFERENCEalias1.CID AND WRITESalias1.AID = AUTHORalias1.AID AND WRITESalias1.PID = PUBLICATIONalias1.PID GROUP BY AUTHORalias1.NAME ) AS DERIVED_TABLEalias1 WHERE DERIVED_TABLEalias0.DERIVED_FIELDalias1 > DERIVED_TABLEalias1.DERIVED_FIELDalias3 AND DERIVED_TABLEalias1.DERIVED_FIELDalias2 = DERIVED_TABLEalias0.DERIVED_FIELDalias0 ;'

        # wrong number of arguments to function COUNT(), change from "," to "||" for sqlite3 to recognize and execute
        if dataset_name == 'advising' and idx == 107:
            d['sql'][0] = 'SELECT COUNT( DISTINCT COURSEalias1.DEPARTMENT || COURSEalias0.NUMBER ) FROM COURSE AS COURSEalias0 , COURSE AS COURSEalias1 , COURSE_PREREQUISITE AS COURSE_PREREQUISITEalias0 , STUDENT_RECORD AS STUDENT_RECORDalias0 WHERE COURSEalias0.COURSE_ID = COURSE_PREREQUISITEalias0.PRE_COURSE_ID AND COURSEalias1.COURSE_ID = COURSE_PREREQUISITEalias0.COURSE_ID AND COURSEalias1.DEPARTMENT = "department0" AND COURSEalias1.NUMBER = number0 AND STUDENT_RECORDalias0.COURSE_ID = COURSEalias0.COURSE_ID AND STUDENT_RECORDalias0.STUDENT_ID = 1 ;'

        # no example was given for level1, so replacing the variable with its value led to errors
        if dataset_name == 'advising' and idx == 132:
            d['variables'][0]['example'] = '300'

        # cannot use count and order without group by; added grouping by actor_id
        if dataset_name == 'imdb' and idx == 79:
            d['sql'][0] = 'SELECT ACTORalias0.NAME FROM ACTOR AS ACTORalias0 , CAST AS CASTalias0 , MOVIE AS MOVIEalias0 WHERE CASTalias0.AID = ACTORalias0.AID AND MOVIEalias0.MID = CASTalias0.MSID GROUP BY ACTORalias0.AID ORDER BY COUNT( DISTINCT ( MOVIEalias0.TITLE ) ) DESC LIMIT 1 ;'

        # cannot use count and order without group by; added grouping by actor_id
        if dataset_name == 'imdb' and idx == 80:
            d['sql'][0] = 'SELECT ACTORalias0.NAME FROM ACTOR AS ACTORalias0 , CAST AS CASTalias0 , DIRECTED_BY AS DIRECTED_BYalias0 , DIRECTOR AS DIRECTORalias0 , MOVIE AS MOVIEalias0 WHERE CASTalias0.AID = ACTORalias0.AID AND DIRECTORalias0.DID = DIRECTED_BYalias0.DID AND MOVIEalias0.MID = CASTalias0.MSID AND MOVIEalias0.MID = DIRECTED_BYalias0.MSID GROUP BY ACTORalias0.AID ORDER BY COUNT( DISTINCT ( MOVIEalias0.TITLE ) ) DESC LIMIT 1 ;'

        # table has "u" in the neighborhood spelling.
        if dataset_name == 'yelp':
            d['sql'][0] = d['sql'][0].replace(n_before, n_after)

        # This override runs after the spelling replacement above, so it is the
        # final query for this datapoint.
        # NOTE(review): the original notebook gives no reason for this override - confirm.
        if dataset_name == 'yelp' and idx == 42:
            d['sql'][0] = 'SELECT NEIGHBOURHOODalias0.NEIGHBOURHOOD_NAME FROM BUSINESS AS BUSINESSalias0 , NEIGHBOURHOOD AS NEIGHBOURHOODalias0 , REVIEW AS REVIEWalias0 , USER AS USERalias0 WHERE NEIGHBOURHOODalias0.BUSINESS_ID = BUSINESSalias0.BUSINESS_ID AND REVIEWalias0.BUSINESS_ID = BUSINESSalias0.BUSINESS_ID AND USERalias0.NAME = "user_name0" AND USERalias0.USER_ID = REVIEWalias0.USER_ID ;'

    # All datasets are concatenated into a single flat list of datapoints.
    orig_datasets.extend(orig_dataset)

In [3]:
# we create the new testset here
new_testset = []
for d in orig_datasets:
    orig_id = d['orig_id']
    db_id, idx = orig_id

    # Keep a datapoint when it is in the test split, or when its dataset is one
    # of the small ones that are used in full; drop everything else.
    if not (d['query-split'] == 'test' or db_id in new_split_defined):
        continue

    # Only the first SQL annotation of the datapoint is used.
    sql = d['sql'][0]
    instance_variables = d['variables']

    # Variable name -> example value for this datapoint.
    # NOTE(review): not read anywhere in the cells visible here - confirm before removing.
    instance_name2examples = {}
    for variable in instance_variables:
        instance_name2examples[variable['name']] = variable['example']

    # Every sentence of the datapoint shares the same database file.
    db_path = 'database/{db_id}/{db_id}.sqlite'.format(db_id=db_id)

    # One new datapoint per natural language sentence; all of them share the
    # same SQL query and the same variables list object.
    for sentence in d['sentences']:
        new_datapoint = {
            'text': sentence['text'],
            'query': sql,
            'variables': instance_variables,
            'orig_id': orig_id,
            'db_id': db_id,
            'db_path': db_path,
        }
        new_testset.append(new_datapoint)

print('There are %d datapoints in the new testset' % len(new_testset))

There are 3509 datapoints in the new testset


In [4]:
import re

# this block implements a function that extracts variable names from text and SQL
# later we use it to ensure that every variable is replaced

# A variable token is lowercase letters/underscores followed by digits,
# e.g. journal_name0; the whole token must match.
variable_pattern = re.compile('^[a-z_]+[0-9]+$')


def extract_variable_names(t):
    """Return the set of variable names that appear in a text or SQL string.

    Double quotes and percent signs are deleted first, then the string is split
    on single spaces. A token counts as a variable when it matches
    ``variable_pattern`` and does not contain the substring 'alias'.
    """
    # Delete every '"' and '%' character in one pass.
    stripped = t.translate(str.maketrans('', '', '"%'))

    var_names = set()
    for token in stripped.split(' '):
        if 'alias' in token:
            continue
        if variable_pattern.match(token):
            var_names.add(token)
    return var_names

# Manual smoke check for extract_variable_names; set the flag to True to run it.
test = False
if test:
    # A SQL query containing one quoted variable.
    sql = 'SELECT BUSINESSalias0.NAME FROM BUSINESS AS BUSINESSalias0 , REVIEW AS REVIEWalias0 WHERE REVIEWalias0.BUSINESS_ID = BUSINESSalias0.BUSINESS_ID AND REVIEWalias0.MONTH = "review_month0" GROUP BY BUSINESSalias0.NAME ORDER BY COUNT( DISTINCT ( REVIEWalias0.TEXT ) ) DESC LIMIT 1 ;'
    sql_names_found = extract_variable_names(sql)
    print(sql_names_found)

    # A natural language sentence containing one variable.
    text = 'return me the homepage of journal_name0 .'
    text_names_found = extract_variable_names(text)
    print(text_names_found)

In [5]:
# this block removes the padding around variable names: each '" name "' (quote, space, name, space, quote) becomes just 'name'
def remove_extra_space_around_variable(t):
    """Return ``t`` with the quote-and-space padding around variables removed.

    For every variable name found by ``extract_variable_names``, each
    occurrence of '" name "' (double quote, space, name, space, double quote)
    is replaced by the bare name. The surrounding quotes are dropped along
    with the spaces.
    """
    cleaned = str(t)
    for name in extract_variable_names(t):
        # Builds the padded form: '" ' + name + ' "'.
        padded = ' '.join(('"', name, '"'))
        cleaned = cleaned.replace(padded, name)
    return cleaned

In [6]:
# orig_ids of the datapoints that violate one of the assumptions checked below.
problematic = set()

for datapoint in new_testset:
    orig_id = datapoint['orig_id']

    # remove the '" ' / ' "' padding surrounding variables in the text
    datapoint['text'] = remove_extra_space_around_variable(datapoint['text'])

    # there should not be such padding around the sql variables; the query is
    # left untouched here and only flagged so that it can be fixed at the source
    if datapoint['query'] != remove_extra_space_around_variable(datapoint['query']):
        problematic.add(orig_id)

    text_vars = extract_variable_names(datapoint['text'])
    sql_vars = extract_variable_names(datapoint['query'])

    # variable name -> full variable record of this datapoint
    instance_variables = {var['name']: var for var in datapoint['variables']}

    # we ensure that all the variables in the sql query and the text can be replaced
    # by some variable in the variable dictionary
    if (text_vars - instance_variables.keys()) or (sql_vars - instance_variables.keys()):
        problematic.add(orig_id)

    # Replace the variables with the examples in the variable dictionary.
    # - Longer names go first, so that a name which is a prefix of another
    #   (e.g. x1 and x10) cannot clobber the longer one; the secondary sort key
    #   makes the order independent of set iteration order.
    # - Names missing from the dictionary are skipped rather than raising
    #   KeyError: the datapoint was already added to `problematic` above, and
    #   skipping lets the loop finish so that every offender is reported.
    for text_var in sorted(text_vars, key=lambda name: (-len(name), name)):
        if text_var not in instance_variables:
            continue
        datapoint['text'] = datapoint['text'].replace(text_var, instance_variables[text_var]['example'])

    for sql_var in sorted(sql_vars, key=lambda name: (-len(name), name)):
        if sql_var not in instance_variables:
            continue
        datapoint['query'] = datapoint['query'].replace(sql_var, instance_variables[sql_var]['example'])

# we can trace back which datapoints do not satisfy the assumptions,
# then go back and fix them manually
print(problematic)

set()


In [7]:
from pprint import pprint

# Pretty-print only the first five datapoints of new_testset for inspection,
# rather than dumping the whole list.
pprint(new_testset[:5])

[{'db_id': 'academic',
 'db_path': 'database/academic/academic.sqlite',
 'orig_id': ('academic', 0),
 'query': 'SELECT JOURNALalias0.HOMEPAGE FROM JOURNAL AS JOURNALalias0 WHERE '
 'JOURNALalias0.NAME = "PVLDB" ;',
 'text': 'return me the homepage of PVLDB .',
 'variables': [{'example': 'PVLDB',
 'location': 'both',
 'name': 'journal_name0',
 'type': 'journal_name'}]},
 {'db_id': 'academic',
 'db_path': 'database/academic/academic.sqlite',
 'orig_id': ('academic', 1),
 'query': 'SELECT AUTHORalias0.HOMEPAGE FROM AUTHOR AS AUTHORalias0 WHERE '
 'AUTHORalias0.NAME = "H. V. Jagadish" ;',
 'text': 'return me the homepage of H. V. Jagadish .',
 'variables': [{'example': 'H. V. Jagadish',
 'location': 'both',
 'name': 'author_name0',
 'type': 'author_name'}]},
 {'db_id': 'academic',
 'db_path': 'database/academic/academic.sqlite',
 'orig_id': ('academic', 2),
 'query': 'SELECT PUBLICATIONalias0.ABSTRACT FROM PUBLICATION AS '
 'PUBLICATIONalias0 WHERE PUBLICATIONalias0.TITLE = "Making databas