simone-papicchio commited on
Commit
2321bd0
·
1 Parent(s): ab37bbe

feat: add TQA prompt and db_schema with INSERT INTO

Browse files
prediction.py CHANGED
@@ -57,6 +57,16 @@ class ModelPrediction:
57
  "Database Schema\n"
58
  "{db_schema}\n"
59
  )
 
 
 
 
 
 
 
 
 
 
60
 
61
  def _reset_pipeline(self, model_name):
62
  if self._model_name != model_name:
@@ -74,7 +84,7 @@ class ModelPrediction:
74
  return matches[-1].strip() if matches else pred
75
 
76
 
77
- def make_prediction(self, question, db_schema, model_name, prompt=None):
78
  if model_name not in self.model_name2pred_func:
79
  raise ValueError(
80
  "Model not supported",
@@ -82,10 +92,13 @@ class ModelPrediction:
82
  self.model_name2pred_func.keys(),
83
  )
84
 
 
 
 
 
85
 
86
- prompt = prompt or self.base_prompt
87
- #prompt = prompt.format(question=question, db_schema=db_schema)
88
-
89
  start_time = time.time()
90
  prediction = self.model_name2pred_func[model_name](prompt)
91
  end_time = time.time()
@@ -93,7 +106,6 @@ class ModelPrediction:
93
  prediction["response"]
94
  )
95
  prediction['time'] = end_time - start_time
96
-
97
  return prediction
98
 
99
 
@@ -133,14 +145,12 @@ class ModelPrediction:
133
  model_name = "openai/gpt-3.5-turbo-0125"
134
  elif "gpt-4o-mini" in model_name:
135
  model_name = "openai/gpt-4o-mini-2024-07-18"
136
- elif "o1-mini" in model_name:
137
- model_name = "openai/o1-mini-2024-09-12"
138
- elif "QwQ" in model_name:
139
- model_name = "together_ai/Qwen/QwQ-32B"
140
  elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
141
  model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
142
  elif "llama-8" in model_name:
143
  model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
 
 
144
  else:
145
  raise ValueError("Model forbidden")
146
 
 
57
  "Database Schema\n"
58
  "{db_schema}\n"
59
  )
60
+ self.base_prompt_QA= (
61
+ "Return the answer of the following question based on the provided database."
62
+ " Return your answer as the result of a query executed over the database."
63
+ " Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n"
64
+ "Return the answer in answer tag as <answer> </answer>"
65
+ " Question\n"
66
+ "{question}\n"
67
+ "Database Schema\n"
68
+ "{db_schema}\n"
69
+ )
70
 
71
  def _reset_pipeline(self, model_name):
72
  if self._model_name != model_name:
 
84
  return matches[-1].strip() if matches else pred
85
 
86
 
87
+ def make_prediction(self, question, db_schema, model_name, prompt=None, task='SP'):
88
  if model_name not in self.model_name2pred_func:
89
  raise ValueError(
90
  "Model not supported",
 
92
  self.model_name2pred_func.keys(),
93
  )
94
 
95
+ if task == 'SP':
96
+ prompt = prompt or self.base_prompt
97
+ else:
98
+ prompt = prompt or self.base_prompt_QA
99
 
100
+ prompt = prompt.format(question=question, db_schema=db_schema)
101
+
 
102
  start_time = time.time()
103
  prediction = self.model_name2pred_func[model_name](prompt)
104
  end_time = time.time()
 
106
  prediction["response"]
107
  )
108
  prediction['time'] = end_time - start_time
 
109
  return prediction
110
 
111
 
 
145
  model_name = "openai/gpt-3.5-turbo-0125"
146
  elif "gpt-4o-mini" in model_name:
147
  model_name = "openai/gpt-4o-mini-2024-07-18"
 
 
 
 
148
  elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
149
  model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
150
  elif "llama-8" in model_name:
151
  model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
152
+ elif "llama-70" in model_name:
153
+ model_name = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
154
  else:
155
  raise ValueError("Model forbidden")
156
 
test_get_db_schema_with_entries.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils_get_db_tables_info import utils_extract_db_schema_as_string
2
+
3
+ def main():
4
+ db = utils_extract_db_schema_as_string(
5
+ db_id='',
6
+ base_path='mytable_7.sqlite',
7
+ normalize=True,
8
+ sql=None,
9
+ get_insert_into=True
10
+ )
11
+
12
+ print(db)
13
+
14
+ if __name__ == '__main__':
15
+ main()
utils_get_db_tables_info.py CHANGED
@@ -4,7 +4,7 @@ import re
4
 
5
 
6
  def utils_extract_db_schema_as_string(
7
- db_id, base_path, normalize=False, sql: str | None = None
8
  ):
9
  """
10
  Extracts the full schema of an SQLite database into a single string.
@@ -19,7 +19,7 @@ def utils_extract_db_schema_as_string(
19
  cursor = connection.cursor()
20
 
21
  # Get the schema entries based on the provided SQL query
22
- schema_entries = _get_schema_entries(cursor, sql)
23
 
24
  # Combine all schema definitions into a single string
25
  schema_string = _combine_schema_entries(schema_entries, normalize)
@@ -27,27 +27,47 @@ def utils_extract_db_schema_as_string(
27
  return schema_string
28
 
29
 
30
- def _get_schema_entries(cursor, sql):
 
31
  """
32
- Retrieves schema entries from the SQLite database.
33
 
34
  :param cursor: SQLite cursor object.
35
  :param sql: Optional SQL query to filter specific tables.
36
- :return: List of schema entries.
 
37
  """
 
 
38
  if sql:
 
39
  cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
40
  tables = [tbl[0] for tbl in cursor.fetchall() if tbl[0].lower() in sql.lower()]
41
- if tables:
42
- tbl_names = ", ".join(f"'{tbl}'" for tbl in tables)
43
- query = f"SELECT sql FROM sqlite_master WHERE type='table' AND name IN ({tbl_names}) AND sql IS NOT NULL;"
44
- else:
45
- query = "SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;"
46
  else:
47
- query = "SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- cursor.execute(query)
50
- return cursor.fetchall()
51
 
52
 
53
  def _combine_schema_entries(schema_entries, normalize):
@@ -59,7 +79,7 @@ def _combine_schema_entries(schema_entries, normalize):
59
  :return: Combined schema string.
60
  """
61
  if not normalize:
62
- return "\n".join(entry[0] for entry in schema_entries)
63
 
64
  return "\n".join(
65
  re.sub(
@@ -77,7 +97,7 @@ def _combine_schema_entries(schema_entries, normalize):
77
  re.sub(
78
  r"\s+",
79
  " ",
80
- entry[0].replace("CREATE TABLE", "").replace("\t", " "),
81
  ).strip(),
82
  ),
83
  ),
 
4
 
5
 
6
  def utils_extract_db_schema_as_string(
7
+ db_id, base_path, normalize=False, sql: str | None = None, get_insert_into: bool = False
8
  ):
9
  """
10
  Extracts the full schema of an SQLite database into a single string.
 
19
  cursor = connection.cursor()
20
 
21
  # Get the schema entries based on the provided SQL query
22
+ schema_entries = _get_schema_entries(cursor, sql, get_insert_into)
23
 
24
  # Combine all schema definitions into a single string
25
  schema_string = _combine_schema_entries(schema_entries, normalize)
 
27
  return schema_string
28
 
29
 
30
+
31
+ def _get_schema_entries(cursor, sql=None, get_insert_into=False):
32
  """
33
+ Retrieves schema entries and optionally data entries from the SQLite database.
34
 
35
  :param cursor: SQLite cursor object.
36
  :param sql: Optional SQL query to filter specific tables.
37
+ :param get_insert_into: Boolean flag to include INSERT INTO statements.
38
+ :return: List of schema and optionally data entries.
39
  """
40
+ entries = []
41
+
42
  if sql:
43
+ # Extract table names from the provided SQL query
44
  cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
45
  tables = [tbl[0] for tbl in cursor.fetchall() if tbl[0].lower() in sql.lower()]
 
 
 
 
 
46
  else:
47
+ # Retrieve all table names
48
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
49
+ tables = [tbl[0] for tbl in cursor.fetchall()]
50
+
51
+ for table in tables:
52
+ # Retrieve the CREATE TABLE statement for each table
53
+ cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;")
54
+ create_table_stmt = cursor.fetchone()
55
+ if create_table_stmt:
56
+ entries.append(create_table_stmt[0])
57
+
58
+ if get_insert_into:
59
+ # Retrieve all data from the table
60
+ cursor.execute(f"SELECT * FROM {table};")
61
+ rows = cursor.fetchall()
62
+ column_names = [description[0] for description in cursor.description]
63
+
64
+ # Generate INSERT INTO statements for each row
65
+ for row in rows:
66
+ values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
67
+ insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
68
+ entries.append(insert_stmt)
69
 
70
+ return entries
 
71
 
72
 
73
  def _combine_schema_entries(schema_entries, normalize):
 
79
  :return: Combined schema string.
80
  """
81
  if not normalize:
82
+ return "\n".join(entry for entry in schema_entries)
83
 
84
  return "\n".join(
85
  re.sub(
 
97
  re.sub(
98
  r"\s+",
99
  " ",
100
+ entry.replace("CREATE TABLE", "").replace("\t", " "),
101
  ).strip(),
102
  ),
103
  ),