{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import sqlparse\n",
    "import pickle as pkl\n",
    "# all text2sql-data datasets that are loaded below\n",
    "dataset_names = [\n",
    "    'academic', 'atis', 'advising', 'geography',\n",
    "    'imdb', 'restaurants', 'scholar', 'yelp',\n",
    "]\n",
    "\n",
    "# these datasets are small, so we use the full set (every split) instead of only the test split\n",
    "new_split_defined = {'restaurants', 'academic', 'imdb', 'yelp'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loading the original datasets from the paper:\n",
    "# Improving Text-to-SQL Evaluation Methodology\n",
    "\n",
    "# a dataset is a list of dictionaries\n",
    "# in the original dictionary, each datapoint might consist of several natural language sentences or SQL queries;\n",
    "# each datapoint is tagged with 'orig_id' = (dataset_name, index) so it can be traced back later\n",
    "\n",
    "# the yelp database spells NEIGHBOURHOOD with a \"u\"; this rename is applied to every yelp query below\n",
    "n_before, n_after = 'NEIGHBORHOOD', 'NEIGHBOURHOOD'\n",
    "\n",
    "orig_datasets = []\n",
    "for dataset_name in dataset_names:\n",
    "    # the with-statement closes the file handle as soon as the json has been parsed\n",
    "    with open('text2sql-data/data/%s.json' % dataset_name) as data_file:\n",
    "        orig_dataset = json.load(data_file)\n",
    "    for idx, d in enumerate(orig_dataset):\n",
    "        d['orig_id'] = (dataset_name, idx)\n",
    "\n",
    "        # fixing annotations here; only the first sql variant d['sql'][0] is patched\n",
    "\n",
    "        # change \"company_name\" to producer name, otherwise there is no variable to replace\n",
    "        if dataset_name == 'imdb' and idx == 27:\n",
    "            d['sql'][0] = 'SELECT MOVIEalias0.TITLE FROM COMPANY AS COMPANYalias0 , COPYRIGHT AS COPYRIGHTalias0 , MOVIE AS MOVIEalias0 WHERE COMPANYalias0.NAME = \"producer_name0\" AND COPYRIGHTalias0.CID = COMPANYalias0.ID AND MOVIEalias0.MID = COPYRIGHTalias0.MSID AND MOVIEalias0.RELEASE_YEAR > movie_release_year0 ;'\n",
    "\n",
    "        # removing the extra space surrounding the variable actor_name0\n",
    "        if dataset_name == 'imdb' and idx == 78:\n",
    "            d['sql'][0] = 'SELECT MAX( DERIVED_TABLEalias0.DERIVED_FIELDalias0 ) FROM ( SELECT COUNT( DISTINCT ( MOVIEalias0.TITLE ) ) AS DERIVED_FIELDalias0 FROM ACTOR AS ACTORalias0 , CAST AS CASTalias0 , MOVIE AS MOVIEalias0 WHERE ACTORalias0.NAME = \"actor_name0\" AND CASTalias0.AID = ACTORalias0.AID AND MOVIEalias0.MID = CASTalias0.MSID GROUP BY MOVIEalias0.RELEASE_YEAR ) AS DERIVED_TABLEalias0 ;'\n",
    "\n",
    "        # there was a scoping error; changed AUTHORalias1 to AUTHORalias0, PUBLICATIONalias1 to PUBLICATIONalias0\n",
    "        if dataset_name == 'academic' and idx == 182:\n",
    "            d['sql'][0] = 'SELECT DERIVED_FIELDalias0 FROM ( SELECT AUTHORalias0.NAME AS DERIVED_FIELDalias0 , COUNT( DISTINCT ( PUBLICATIONalias0.TITLE ) ) AS DERIVED_FIELDalias1 FROM AUTHOR AS AUTHORalias0 , CONFERENCE AS CONFERENCEalias0 , PUBLICATION AS PUBLICATIONalias0 , WRITES AS WRITESalias0 WHERE CONFERENCEalias0.NAME = \"conference_name0\" AND PUBLICATIONalias0.CID = CONFERENCEalias0.CID AND WRITESalias0.AID = AUTHORalias0.AID AND WRITESalias0.PID = PUBLICATIONalias0.PID GROUP BY AUTHORalias0.NAME ) AS DERIVED_TABLEalias0 , ( SELECT AUTHORalias1.NAME AS DERIVED_FIELDalias2 , COUNT( DISTINCT ( PUBLICATIONalias1.TITLE ) ) AS DERIVED_FIELDalias3 FROM AUTHOR AS AUTHORalias1 , CONFERENCE AS CONFERENCEalias1 , PUBLICATION AS PUBLICATIONalias1 , WRITES AS WRITESalias1 WHERE CONFERENCEalias1.NAME = \"conference_name1\" AND PUBLICATIONalias1.CID = CONFERENCEalias1.CID AND WRITESalias1.AID = AUTHORalias1.AID AND WRITESalias1.PID = PUBLICATIONalias1.PID GROUP BY AUTHORalias1.NAME ) AS DERIVED_TABLEalias1 WHERE DERIVED_TABLEalias0.DERIVED_FIELDalias1 > DERIVED_TABLEalias1.DERIVED_FIELDalias3 AND DERIVED_TABLEalias1.DERIVED_FIELDalias2 = DERIVED_TABLEalias0.DERIVED_FIELDalias0 ;'\n",
    "\n",
    "        # wrong number of arguments to function COUNT(); changed \",\" to \"||\" so that sqlite3 can recognize and execute it\n",
    "        if dataset_name == 'advising' and idx == 107:\n",
    "            d['sql'][0] = 'SELECT COUNT( DISTINCT COURSEalias1.DEPARTMENT || COURSEalias0.NUMBER ) FROM COURSE AS COURSEalias0 , COURSE AS COURSEalias1 , COURSE_PREREQUISITE AS COURSE_PREREQUISITEalias0 , STUDENT_RECORD AS STUDENT_RECORDalias0 WHERE COURSEalias0.COURSE_ID = COURSE_PREREQUISITEalias0.PRE_COURSE_ID AND COURSEalias1.COURSE_ID = COURSE_PREREQUISITEalias0.COURSE_ID AND COURSEalias1.DEPARTMENT = \"department0\" AND COURSEalias1.NUMBER = number0 AND STUDENT_RECORDalias0.COURSE_ID = COURSEalias0.COURSE_ID AND STUDENT_RECORDalias0.STUDENT_ID = 1 ;'\n",
    "\n",
    "        # no example was given for level1, so replacing the variable with its example value failed\n",
    "        if dataset_name == 'advising' and idx == 132:\n",
    "            d['variables'][0]['example'] = '300'\n",
    "\n",
    "        # cannot use count and order without group by; added grouping by actor_id\n",
    "        if dataset_name == 'imdb' and idx == 79:\n",
    "            d['sql'][0] = 'SELECT ACTORalias0.NAME FROM ACTOR AS ACTORalias0 , CAST AS CASTalias0 , MOVIE AS MOVIEalias0 WHERE CASTalias0.AID = ACTORalias0.AID AND MOVIEalias0.MID = CASTalias0.MSID GROUP BY ACTORalias0.AID ORDER BY COUNT( DISTINCT ( MOVIEalias0.TITLE ) ) DESC LIMIT 1 ;'\n",
    "\n",
    "        # cannot use count and order without group by; added grouping by actor_id\n",
    "        if dataset_name == 'imdb' and idx == 80:\n",
    "            d['sql'][0] = 'SELECT ACTORalias0.NAME FROM ACTOR AS ACTORalias0 , CAST AS CASTalias0 , DIRECTED_BY AS DIRECTED_BYalias0 , DIRECTOR AS DIRECTORalias0 , MOVIE AS MOVIEalias0 WHERE CASTalias0.AID = ACTORalias0.AID AND DIRECTORalias0.DID = DIRECTED_BYalias0.DID AND MOVIEalias0.MID = CASTalias0.MSID AND MOVIEalias0.MID = DIRECTED_BYalias0.MSID GROUP BY ACTORalias0.AID ORDER BY COUNT( DISTINCT ( MOVIEalias0.TITLE ) ) DESC LIMIT 1 ;'\n",
    "\n",
    "        # the yelp tables use the \"u\" spelling of neighbourhood\n",
    "        if dataset_name == 'yelp':\n",
    "            d['sql'][0] = d['sql'][0].replace(n_before, n_after)\n",
    "\n",
    "        # hand-written replacement for yelp 42 (it already uses the NEIGHBOURHOOD spelling)\n",
    "        # NOTE(review): the reason for this override is not recorded here\n",
    "        if dataset_name == 'yelp' and idx == 42:\n",
    "            d['sql'][0] = 'SELECT NEIGHBOURHOODalias0.NEIGHBOURHOOD_NAME FROM BUSINESS AS BUSINESSalias0 , NEIGHBOURHOOD AS NEIGHBOURHOODalias0 , REVIEW AS REVIEWalias0 , USER AS USERalias0 WHERE NEIGHBOURHOODalias0.BUSINESS_ID = BUSINESSalias0.BUSINESS_ID AND REVIEWalias0.BUSINESS_ID = BUSINESSalias0.BUSINESS_ID AND USERalias0.NAME = \"user_name0\" AND USERalias0.USER_ID = REVIEWalias0.USER_ID ;'\n",
    "\n",
    "    orig_datasets.extend(orig_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 3509 datapoints in the new testset\n"
     ]
    }
   ],
   "source": [
    "# we create the new testset here: one datapoint per natural language sentence\n",
    "new_testset = []\n",
    "for d in orig_datasets:\n",
    "    orig_id = d['orig_id']\n",
    "    db_id, idx = orig_id\n",
    "\n",
    "    # we only incorporate the test split if the dataset is large enough\n",
    "    # otherwise (db_id in new_split_defined) we incorporate the entire dataset\n",
    "    if d['query-split'] != 'test' and db_id not in new_split_defined:\n",
    "        continue\n",
    "    sql = d['sql'][0]\n",
    "    instance_variables = d['variables']\n",
    "    # the path only depends on db_id, so build it once per original datapoint\n",
    "    db_path = 'database/{db_id}/{db_id}.sqlite'.format(db_id=db_id)\n",
    "\n",
    "    # we create a new datapoint for each natural language query;\n",
    "    # all of them share the first sql variant and the same variables list object\n",
    "    for sentence in d['sentences']:\n",
    "        new_datapoint = {\n",
    "            'text': sentence['text'],\n",
    "            'query': sql,\n",
    "            'variables': instance_variables,\n",
    "            'orig_id': orig_id,\n",
    "            'db_id': db_id,\n",
    "            'db_path': db_path\n",
    "        }\n",
    "        new_testset.append(new_datapoint)\n",
    "print('There are %d datapoints in the new testset' % len(new_testset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# this block implements a function that extracts variable names from text and sql;\n",
    "# later we use it to ensure that every variable is replaced\n",
    "\n",
    "variable_pattern = re.compile('^[a-z_]+[0-9]+$')\n",
    "\n",
    "def extract_variable_names(t):\n",
    "    \"\"\"Return the set of placeholder names (e.g. 'journal_name0') that occur in t.\n",
    "\n",
    "    Double quotes and '%' signs are stripped first, then t is split on single spaces.\n",
    "    A token counts as a placeholder when it matches variable_pattern (lowercase\n",
    "    letters/underscores followed by digits) and does not contain 'alias'.\n",
    "    \"\"\"\n",
    "    cleaned = t.replace('\"', '').replace('%', '')\n",
    "    found = set()\n",
    "    for token in cleaned.split(' '):\n",
    "        if 'alias' in token:\n",
    "            continue\n",
    "        if variable_pattern.match(token):\n",
    "            found.add(token)\n",
    "    return found\n",
    "\n",
    "# flip to True to print the variables found in two example strings\n",
    "test = False\n",
    "if test:\n",
    "    sql = 'SELECT BUSINESSalias0.NAME FROM BUSINESS AS BUSINESSalias0 , REVIEW AS REVIEWalias0 WHERE REVIEWalias0.BUSINESS_ID = BUSINESSalias0.BUSINESS_ID AND REVIEWalias0.MONTH = \"review_month0\" GROUP BY BUSINESSalias0.NAME ORDER BY COUNT( DISTINCT ( REVIEWalias0.TEXT ) ) DESC LIMIT 1 ;'\n",
    "    print(extract_variable_names(sql))\n",
    "    text = 'return me the homepage of journal_name0 .'\n",
    "    print(extract_variable_names(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this block removes extra space surrounding variable names\n",
    "def remove_extra_space_around_variable(t):\n",
    "    \"\"\"Return a copy of t in which every '\" <var> \"' is collapsed to '<var>'.\n",
    "\n",
    "    <var> ranges over the names returned by extract_variable_names(t).\n",
    "    NOTE(review): the double quotes are dropped together with the spaces, so\n",
    "    '\" author_name0 \"' becomes 'author_name0' rather than '\"author_name0\"'; confirm this is intended.\n",
    "    \"\"\"\n",
    "    names = extract_variable_names(t)\n",
    "    result = str(t)\n",
    "    for name in names:\n",
    "        padded = '\" {} \"'.format(name)\n",
    "        result = result.replace(padded, name)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "set()\n"
     ]
    }
   ],
   "source": [
    "problematic = set()\n",
    "\n",
    "# NOTE: this cell rewrites the datapoints of new_testset in place, so run it once right after new_testset is built\n",
    "for datapoint in new_testset:\n",
    "    orig_id = datapoint['orig_id']\n",
    "\n",
    "    # collapse `\" var \"` to `var` in the text\n",
    "    datapoint['text'] = remove_extra_space_around_variable(datapoint['text'])\n",
    "\n",
    "    # there should not be extra whitespace surrounding the sql variables\n",
    "    if datapoint['query'] != remove_extra_space_around_variable(datapoint['query']):\n",
    "        problematic.add(orig_id)\n",
    "\n",
    "    text_vars = extract_variable_names(datapoint['text'])\n",
    "    sql_vars = extract_variable_names(datapoint['query'])\n",
    "\n",
    "    instance_variables = {v['name']: v for v in datapoint['variables']}\n",
    "\n",
    "    # we ensure that all the variables in the sql query and the text can be replaced\n",
    "    # by some variable in the variable dictionary\n",
    "    if text_vars - instance_variables.keys() or sql_vars - instance_variables.keys():\n",
    "        problematic.add(orig_id)\n",
    "\n",
    "    # replace the variables with the examples in the variable dictionary.\n",
    "    # unknown names are skipped (they are already recorded in `problematic`) instead of raising KeyError,\n",
    "    # and longer names go first so that e.g. 'name1' cannot overwrite part of 'author_name1' or 'name10';\n",
    "    # plain set iteration order of strings is not stable across kernel restarts (hash randomization)\n",
    "    for field, var_names in (('text', text_vars), ('query', sql_vars)):\n",
    "        known = [v for v in var_names if v in instance_variables]\n",
    "        for var in sorted(known, key=lambda name: (-len(name), name)):\n",
    "            datapoint[field] = datapoint[field].replace(var, instance_variables[var]['example'])\n",
    "\n",
    "# we can trace back which datapoints do not satisfy the assumption,\n",
    "# then go back and fix it manually\n",
    "print(problematic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'db_id': 'academic',\n",
      "  'db_path': 'database/academic/academic.sqlite',\n",
      "  'orig_id': ('academic', 0),\n",
      "  'query': 'SELECT JOURNALalias0.HOMEPAGE FROM JOURNAL AS JOURNALalias0 WHERE '\n",
      "           'JOURNALalias0.NAME = \"PVLDB\" ;',\n",
      "  'text': 'return me the homepage of PVLDB .',\n",
      "  'variables': [{'example': 'PVLDB',\n",
      "                 'location': 'both',\n",
      "                 'name': 'journal_name0',\n",
      "                 'type': 'journal_name'}]},\n",
      " {'db_id': 'academic',\n",
      "  'db_path': 'database/academic/academic.sqlite',\n",
      "  'orig_id': ('academic', 1),\n",
      "  'query': 'SELECT AUTHORalias0.HOMEPAGE FROM AUTHOR AS AUTHORalias0 WHERE '\n",
      "           'AUTHORalias0.NAME = \"H. V. Jagadish\" ;',\n",
      "  'text': 'return me the homepage of H. V. Jagadish .',\n",
      "  'variables': [{'example': 'H. V. Jagadish',\n",
      "                 'location': 'both',\n",
      "                 'name': 'author_name0',\n",
      "                 'type': 'author_name'}]},\n",
      " {'db_id': 'academic',\n",
      "  'db_path': 'database/academic/academic.sqlite',\n",
      "  'orig_id': ('academic', 2),\n",
      "  'query': 'SELECT PUBLICATIONalias0.ABSTRACT FROM PUBLICATION AS '\n",
      "           'PUBLICATIONalias0 WHERE PUBLICATIONalias0.TITLE = \"Making database '\n",
      "           'systems usable\" ;',\n",
      "  'text': 'return me the abstract of Making database systems usable .',\n",
      "  'variables': [{'example': 'Making database systems usable',\n",
      "                 'location': 'both',\n",
      "                 'name': 'publication_title0',\n",
      "                 'type': 'publication_title'}]},\n",
      " {'db_id': 'academic',\n",
      "  'db_path': 'database/academic/academic.sqlite',\n",
      "  'orig_id': ('academic', 3),\n",
      "  'query': 'SELECT PUBLICATIONalias0.YEAR FROM PUBLICATION AS '\n",
      "           'PUBLICATIONalias0 WHERE PUBLICATIONalias0.TITLE = \"Making database '\n",
      "           'systems usable\" ;',\n",
      "  'text': 'return me the year of Making database systems usable',\n",
      "  'variables': [{'example': 'Making database systems usable',\n",
      "                 'location': 'both',\n",
      "                 'name': 'publication_title0',\n",
      "                 'type': 'publication_title'}]},\n",
      " {'db_id': 'academic',\n",
      "  'db_path': 'database/academic/academic.sqlite',\n",
      "  'orig_id': ('academic', 3),\n",
      "  'query': 'SELECT PUBLICATIONalias0.YEAR FROM PUBLICATION AS '\n",
      "           'PUBLICATIONalias0 WHERE PUBLICATIONalias0.TITLE = \"Making database '\n",
      "           'systems usable\" ;',\n",
      "  'text': 'return me the year of Making database systems usable .',\n",
      "  'variables': [{'example': 'Making database systems usable',\n",
      "                 'location': 'both',\n",
      "                 'name': 'publication_title0',\n",
      "                 'type': 'publication_title'}]}]\n"
     ]
    }
   ],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "# peek at the first five datapoints of the new testset\n",
    "preview = new_testset[:5]\n",
    "pprint(preview)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}