import json
import sqlparse
import pickle as pkl

dataset_names = ['academic', 'atis', 'advising', 'geography', 'imdb', 'restaurants', 'scholar', 'yelp']

# these datasets are small, so we use the full set.
new_split_defined = {'restaurants', 'academic', 'imdb', 'yelp'}

# Manual corrections to the first SQL template of individual datapoints,
# keyed by (dataset_name, index of the datapoint inside that dataset's JSON file).
SQL_FIXES = {
    # change "company_name" to producer name, otherwise there is no variable to replace
    ('imdb', 27): 'SELECT MOVIEalias0.TITLE FROM COMPANY AS COMPANYalias0 , COPYRIGHT AS COPYRIGHTalias0 , MOVIE AS MOVIEalias0 WHERE COMPANYalias0.NAME = "producer_name0" AND COPYRIGHTalias0.CID = COMPANYalias0.ID AND MOVIEalias0.MID = COPYRIGHTalias0.MSID AND MOVIEalias0.RELEASE_YEAR > movie_release_year0 ;',

    # removing the extra space surrounding the variable actor_name0
    ('imdb', 78): 'SELECT MAX( DERIVED_TABLEalias0.DERIVED_FIELDalias0 ) FROM ( SELECT COUNT( DISTINCT ( MOVIEalias0.TITLE ) ) AS DERIVED_FIELDalias0 FROM ACTOR AS ACTORalias0 , CAST AS CASTalias0 , MOVIE AS MOVIEalias0 WHERE ACTORalias0.NAME = "actor_name0" AND CASTalias0.AID = ACTORalias0.AID AND MOVIEalias0.MID = CASTalias0.MSID GROUP BY MOVIEalias0.RELEASE_YEAR ) AS DERIVED_TABLEalias0 ;',

    # there was a scoping error; changed AUTHORalias1 to AUTHORalias0, PUBLICATIONalias1 to PUBLICATIONalias0
    ('academic', 182): 'SELECT DERIVED_FIELDalias0 FROM ( SELECT AUTHORalias0.NAME AS DERIVED_FIELDalias0 , COUNT( DISTINCT ( PUBLICATIONalias0.TITLE ) ) AS DERIVED_FIELDalias1 FROM AUTHOR AS AUTHORalias0 , CONFERENCE AS CONFERENCEalias0 , PUBLICATION AS PUBLICATIONalias0 , WRITES AS WRITESalias0 WHERE CONFERENCEalias0.NAME = "conference_name0" AND PUBLICATIONalias0.CID = CONFERENCEalias0.CID AND WRITESalias0.AID = AUTHORalias0.AID AND WRITESalias0.PID = PUBLICATIONalias0.PID GROUP BY AUTHORalias0.NAME ) AS DERIVED_TABLEalias0 , ( SELECT AUTHORalias1.NAME AS DERIVED_FIELDalias2 , COUNT( DISTINCT ( PUBLICATIONalias1.TITLE ) ) AS DERIVED_FIELDalias3 FROM AUTHOR AS AUTHORalias1 , CONFERENCE AS CONFERENCEalias1 , PUBLICATION AS PUBLICATIONalias1 , WRITES AS WRITESalias1 WHERE CONFERENCEalias1.NAME = "conference_name1" AND PUBLICATIONalias1.CID = CONFERENCEalias1.CID AND WRITESalias1.AID = AUTHORalias1.AID AND WRITESalias1.PID = PUBLICATIONalias1.PID GROUP BY AUTHORalias1.NAME ) AS DERIVED_TABLEalias1 WHERE DERIVED_TABLEalias0.DERIVED_FIELDalias1 > DERIVED_TABLEalias1.DERIVED_FIELDalias3 AND DERIVED_TABLEalias1.DERIVED_FIELDalias2 = DERIVED_TABLEalias0.DERIVED_FIELDalias0 ;',

    # wrong number of arguments to function COUNT(), change from "," to "||" for sqlite3 to recognize and execute
    ('advising', 107): 'SELECT COUNT( DISTINCT COURSEalias1.DEPARTMENT || COURSEalias0.NUMBER ) FROM COURSE AS COURSEalias0 , COURSE AS COURSEalias1 , COURSE_PREREQUISITE AS COURSE_PREREQUISITEalias0 , STUDENT_RECORD AS STUDENT_RECORDalias0 WHERE COURSEalias0.COURSE_ID = COURSE_PREREQUISITEalias0.PRE_COURSE_ID AND COURSEalias1.COURSE_ID = COURSE_PREREQUISITEalias0.COURSE_ID AND COURSEalias1.DEPARTMENT = "department0" AND COURSEalias1.NUMBER = number0 AND STUDENT_RECORDalias0.COURSE_ID = COURSEalias0.COURSE_ID AND STUDENT_RECORDalias0.STUDENT_ID = 1 ;',

    # cannot use count and order without group by; added grouping by actor_id
    ('imdb', 79): 'SELECT ACTORalias0.NAME FROM ACTOR AS ACTORalias0 , CAST AS CASTalias0 , MOVIE AS MOVIEalias0 WHERE CASTalias0.AID = ACTORalias0.AID AND MOVIEalias0.MID = CASTalias0.MSID GROUP BY ACTORalias0.AID ORDER BY COUNT( DISTINCT ( MOVIEalias0.TITLE ) ) DESC LIMIT 1 ;',

    # cannot use count and order without group by; added grouping by actor_id
    ('imdb', 80): 'SELECT ACTORalias0.NAME FROM ACTOR AS ACTORalias0 , CAST AS CASTalias0 , DIRECTED_BY AS DIRECTED_BYalias0 , DIRECTOR AS DIRECTORalias0 , MOVIE AS MOVIEalias0 WHERE CASTalias0.AID = ACTORalias0.AID AND DIRECTORalias0.DID = DIRECTED_BYalias0.DID AND MOVIEalias0.MID = CASTalias0.MSID AND MOVIEalias0.MID = DIRECTED_BYalias0.MSID GROUP BY ACTORalias0.AID ORDER BY COUNT( DISTINCT ( MOVIEalias0.TITLE ) ) DESC LIMIT 1 ;',

    # full rewrite of this yelp query; it takes precedence over the generic spelling fix below
    ('yelp', 42): 'SELECT NEIGHBOURHOODalias0.NEIGHBOURHOOD_NAME FROM BUSINESS AS BUSINESSalias0 , NEIGHBOURHOOD AS NEIGHBOURHOODalias0 , REVIEW AS REVIEWalias0 , USER AS USERalias0 WHERE NEIGHBOURHOODalias0.BUSINESS_ID = BUSINESSalias0.BUSINESS_ID AND REVIEWalias0.BUSINESS_ID = BUSINESSalias0.BUSINESS_ID AND USERalias0.NAME = "user_name0" AND USERalias0.USER_ID = REVIEWalias0.USER_ID ;',
}

# Corrections to the 'example' value of a datapoint's first variable, same keying as SQL_FIXES.
VARIABLE_EXAMPLE_FIXES = {
    # there was not example given for level1 and hence replacing variable with values leads to errors
    ('advising', 132): '300',
}

# table has "u" in the neighborhood spelling.
# (loop-invariant, so defined once instead of on every datapoint)
n_before, n_after = 'NEIGHBORHOOD', 'NEIGHBOURHOOD'

# loading the original datasets from the paper:
# Improving Text-to-SQL Evaluation Methodology

# a dataset is a list of dictionaries
# in the original dictionary, each datapoint might consist of several natural language sentences or SQL
orig_datasets = []
for dataset_name in dataset_names:
    # use a context manager so the file handle is closed as soon as the JSON is parsed
    with open('text2sql-data/data/%s.json' % dataset_name) as dataset_file:
        orig_dataset = json.load(dataset_file)

    for idx, d in enumerate(orig_dataset):
        # remember where the datapoint came from; later cells report problems by this id
        d['orig_id'] = (dataset_name, idx)

        # fixing annotations here (only the first SQL template of a datapoint is patched)
        if dataset_name == 'yelp':
            d['sql'][0] = d['sql'][0].replace(n_before, n_after)

        if d['orig_id'] in SQL_FIXES:
            d['sql'][0] = SQL_FIXES[d['orig_id']]

        if d['orig_id'] in VARIABLE_EXAMPLE_FIXES:
            d['variables'][0]['example'] = VARIABLE_EXAMPLE_FIXES[d['orig_id']]

    orig_datasets.extend(orig_dataset)

# we create the new testset here
new_testset = []
for d in orig_datasets:
    orig_id = d['orig_id']
    db_id, idx = orig_id

    # we only incorporate the test split if the dataset is large enough
    # otherwise we incorporate the entire dataset
    if d['query-split'] != 'test' and db_id not in new_split_defined:
        continue
    sql = d['sql'][0]
    instance_variables = d['variables']

    # we create a new datapoint for each natural language query
    # (all datapoints of one original entry share the same 'variables' list object)
    for sentence in d['sentences']:
        new_datapoint = {
            'text': sentence['text'],
            'query': sql,
            'variables': instance_variables,
            'orig_id': orig_id,
            'db_id': db_id,
            'db_path': 'database/{db_id}/{db_id}.sqlite'.format(db_id=db_id)
        }
        new_testset.append(new_datapoint)
print('There are %d datapoints in the new testset' % len(new_testset))
import re

# this block implements a function that extract variable names from text and sql
# later we use it to ensure that every variable is replaced

# a variable is a run of lowercase letters/underscores followed by at least one digit,
# e.g. journal_name0
variable_pattern = re.compile('^[a-z_]+[0-9]+$')


def extract_variable_names(t):
    """Return the set of placeholder names (e.g. ``journal_name0``) found in ``t``.

    Double quotes and ``%`` characters are stripped first, so quoted placeholders
    and LIKE-patterns such as ``"%actor_name0%"`` are recognised. The remaining
    string is split on whitespace and every token that fully matches
    ``variable_pattern`` and does not contain ``alias`` is collected.

    Args:
        t: a natural-language sentence or SQL template.

    Returns:
        A set of placeholder name strings; empty if none are present.
    """
    # split() with no argument splits on any whitespace run, so placeholders next to
    # tabs/newlines are found and never carry whitespace (split(' ') let 'name0\n'
    # through because '$' also matches before a trailing newline)
    tokens = t.replace('"', '').replace('%', '').split()
    return {v for v in tokens if variable_pattern.fullmatch(v) and 'alias' not in v}


# quick sanity checks of the extractor (replaces the former `if test:` print scaffold)
assert extract_variable_names(
    'SELECT BUSINESSalias0.NAME FROM BUSINESS AS BUSINESSalias0 , REVIEW AS REVIEWalias0 WHERE REVIEWalias0.BUSINESS_ID = BUSINESSalias0.BUSINESS_ID AND REVIEWalias0.MONTH = "review_month0" GROUP BY BUSINESSalias0.NAME ORDER BY COUNT( DISTINCT ( REVIEWalias0.TEXT ) ) DESC LIMIT 1 ;'
) == {'review_month0'}
assert extract_variable_names('return me the homepage of journal_name0 .') == {'journal_name0'}


# this block removes extra space surrounding variable names
def remove_extra_space_around_variable(t):
    """Collapse every ``" <var> "`` occurrence in ``t`` to the bare ``<var>``.

    For each placeholder found by ``extract_variable_names``, the exact sequence
    double-quote, space, name, space, double-quote is replaced by the name alone.
    Note that the surrounding quotes are dropped together with the spaces.
    Placeholders that are not wrapped in exactly that form are left untouched.

    Args:
        t: a natural-language sentence or SQL template.

    Returns:
        A new string; ``t`` itself is not modified.
    """
    result = str(t)
    for v in extract_variable_names(t):
        result = result.replace('" ' + v + ' "', v)
    return result
problematic = set()


def _substitute_variables(template, var_names, instance_variables):
    """Replace each placeholder in ``template`` with its example value in one pass.

    Args:
        template: sentence or SQL string containing placeholder names.
        var_names: placeholder names detected in ``template``.
        instance_variables: dict mapping placeholder name -> variable dict that
            has an ``'example'`` entry.

    Returns:
        ``template`` with every known placeholder substituted. Names absent from
        ``instance_variables`` are left in place instead of raising ``KeyError``.
    """
    known = [v for v in var_names if v in instance_variables]
    if not known:
        return template
    # longest name first, so that e.g. 'city_name0' is never clobbered by a
    # replacement of its substring 'name0'; the secondary key makes the order
    # independent of set iteration order
    known.sort(key=lambda v: (-len(v), v))
    pattern = re.compile('|'.join(re.escape(v) for v in known))
    # a function replacement keeps backslashes in example values literal, and the
    # single pass means inserted example text is never re-scanned for placeholders
    return pattern.sub(lambda m: instance_variables[m.group(0)]['example'], template)


for datapoint in new_testset:
    orig_id = datapoint['orig_id']

    # remove extra whitespace surrounding the text
    datapoint['text'] = remove_extra_space_around_variable(datapoint['text'])

    # there should not be extra whitespace surrounding the sql variables
    if datapoint['query'] != remove_extra_space_around_variable(datapoint['query']):
        problematic.add(orig_id)

    text_vars = extract_variable_names(datapoint['text'])
    sql_vars = extract_variable_names(datapoint['query'])

    instance_variables = {var['name']: var for var in datapoint['variables']}

    # we ensure that all the variables in the sql query and the text can be replaced
    # by some variable in the variable dictionary
    if (text_vars | sql_vars) - instance_variables.keys():
        # record the datapoint and keep going, so the full list gets printed below
        problematic.add(orig_id)

    # replace the variables with the examples in the variable dictionary
    datapoint['text'] = _substitute_variables(datapoint['text'], text_vars, instance_variables)
    datapoint['query'] = _substitute_variables(datapoint['query'], sql_vars, instance_variables)

# we can trace back which datapoints do not satisfy the assumption,
# then go back and fix it manually
print(problematic)

from pprint import pprint

# preview the first few finished datapoints
pprint(new_testset[:5])