simone-papicchio committed
Commit 6ce82f5 · 1 Parent(s): ffec641

feat: add model prediction for text2sql prompt

Files changed (2):
  1. prediction.py +41 -35
  2. test_prediction.py +1 -1
prediction.py CHANGED

@@ -18,7 +18,7 @@ else:
 
 
 from transformers import pipeline as hf_pipeline
-import torch
+
 import litellm
 
 from tqdm import tqdm
@@ -27,18 +27,25 @@ from tqdm import tqdm
 class ModelPrediction:
     def __init__(self):
         self.model_name2pred_func = {
-            "gpt-3.5": self._model_prediction("gpt-3.5"),
-            "gpt-4o-mini": self._model_prediction("gpt-4o-mini"),
-            "o1-mini": self._model_prediction("o1-mini"),
-            "QwQ": self._model_prediction("QwQ"),
-            "DeepSeek-R1-Distill-Llama-70B": self._model_prediction(
+            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
+            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
+            "o1-mini": self._init_model_prediction("o1-mini"),
+            "QwQ": self._init_model_prediction("QwQ"),
+            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                 "DeepSeek-R1-Distill-Llama-70B"
             ),
-            "llama-8": self._model_prediction("llama-8"),
+            "llama-8": self._init_model_prediction("llama-8"),
         }
 
         self._model_name = None
         self._pipeline = None
+        self.base_prompt = (
+            "Translate the following question in SQL code to be executed over the database to fetch the answer. Return the sql code in ```sql ```\n"
+            " Question\n"
+            "{question}\n"
+            "Database Schema\n"
+            "{db_schema}\n"
+        )
 
     @property
     def pipeline(self):
@@ -46,7 +53,6 @@ class ModelPrediction:
             self._pipeline = hf_pipeline(
                 task="text-generation",
                 model=self._model_name,
-                torch_dtype=torch.bfloat16,
                 device_map="auto",
             )
         return self._pipeline
@@ -67,14 +73,8 @@ class ModelPrediction:
         matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
         return matches[-1].strip() if matches else pred
 
-    def make_predictions(self, prompts, model_name) -> list[dict]:
-        preds = []
-        for prompt in tqdm(prompts, desc=f"Analyzing Prompt with {model_name}"):
-            pred = self.make_prediction(prompt, model_name)
-            preds.append(pred)
-        return preds
 
-    def make_prediction(self, prompt, model_name):
+    def make_prediction(self, question, db_schema, model_name, prompt=None):
         if model_name not in self.model_name2pred_func:
             raise ValueError(
                 "Model not supported",
@@ -82,32 +82,17 @@ class ModelPrediction:
                 self.model_name2pred_func.keys(),
             )
 
+        prompt = prompt or self.base_prompt
+        prompt = prompt.format(question=question, db_schema=db_schema)
+        print(prompt)
+
         prediction = self.model_name2pred_func[model_name](prompt)
         prediction["response_parsed"] = self._extract_answer_from_pred(
             prediction["response"]
         )
         return prediction
 
-    def _model_prediction(self, model_name):
-        predict_fun = self.predict_with_api
-        if "gpt-3.5" in model_name:
-            model_name = "openai/gpt-3.5-turbo-0125"
-        elif "gpt-4o-mini" in model_name:
-            model_name = "openai/gpt-4o-mini-2024-07-18"
-        elif "o1-mini" in model_name:
-            model_name = "openai/o1-mini-2024-09-12"
-        elif "QwQ" in model_name:
-            model_name = "together_ai/Qwen/QwQ-32B"
-        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
-            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
-        elif "llama-8" in model_name:
-            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
-            predict_fun = self.predict_with_hf
-        else:
-            raise ValueError("Model forbidden")
-
-        return partial(predict_fun, model_name=model_name)
-
     def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]:
         response = litellm.completion(
             model=model_name,
@@ -127,3 +112,24 @@ class ModelPrediction:
             "generated_text"
         ][-1]["content"]
         return {"response": response, "cost": 0.0}
+
+    def _init_model_prediction(self, model_name):
+        predict_fun = self.predict_with_api
+        if "gpt-3.5" in model_name:
+            model_name = "openai/gpt-3.5-turbo-0125"
+        elif "gpt-4o-mini" in model_name:
+            model_name = "openai/gpt-4o-mini-2024-07-18"
+        elif "o1-mini" in model_name:
+            model_name = "openai/o1-mini-2024-09-12"
+        elif "QwQ" in model_name:
+            model_name = "together_ai/Qwen/QwQ-32B"
+        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
+            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
+        elif "llama-8" in model_name:
+            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+            predict_fun = self.predict_with_hf
+        else:
+            raise ValueError("Model forbidden")
+
+        return partial(predict_fun, model_name=model_name)
+
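For review context, a minimal usage sketch of the new entry point (not part of the commit; the question, schema string, and key handling are illustrative assumptions):

    from prediction import ModelPrediction

    # Sketch: with prompt=None, make_prediction falls back to self.base_prompt
    # and fills {question} and {db_schema} via str.format before dispatching
    # to the predictor chosen in _init_model_prediction.
    model = ModelPrediction()
    result = model.make_prediction(
        question="How many players are older than 30?",  # hypothetical question
        db_schema="CREATE TABLE Player(Name, Age)",
        model_name="gpt-4o-mini",  # routed through litellm; assumes OPENAI_API_KEY is exported
    )
    print(result["response_parsed"])  # SQL extracted from the last ```sql ...``` block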
test_prediction.py CHANGED

@@ -3,7 +3,7 @@ from prediction import ModelPrediction
 
 def main():
     model = ModelPrediction()
-    response = model.make_prediction("Hi, how are you?", "llama-8")
+    response = model.make_prediction(question='What is the name of Simone', db_schema='CREATE TABLE Player(Name, Age)', model_name="gpt-3.5", prompt='{question} {db_schema}')
     print(response)  # dict[response, response_parsed, cost]
 
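The overridden prompt in the test bypasses base_prompt entirely; any template naming the {question} and {db_schema} placeholders works. A sketch of just the substitution step make_prediction performs (pure str.format, no model call):

    template = '{question} {db_schema}'
    prompt = template.format(
        question='What is the name of Simone',
        db_schema='CREATE TABLE Player(Name, Age)',
    )
    print(prompt)  # What is the name of Simone CREATE TABLE Player(Name, Age)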