simone-papicchio committed
Commit ffec641 · 1 Parent(s): 220b4dd

feat add model on zeroGpu

Files changed (3):
  1. prediction.py +19 -17
  2. requirements.txt +1 -1
  3. test_prediction.py +4 -2
prediction.py CHANGED

@@ -21,6 +21,8 @@ from transformers import pipeline as hf_pipeline
 import torch
 import litellm
 
+from tqdm import tqdm
+
 
 class ModelPrediction:
     def __init__(self):
@@ -32,6 +34,7 @@ class ModelPrediction:
             "DeepSeek-R1-Distill-Llama-70B": self._model_prediction(
                 "DeepSeek-R1-Distill-Llama-70B"
             ),
+            "llama-8": self._model_prediction("llama-8"),
         }
 
         self._model_name = None
@@ -50,6 +53,7 @@ class ModelPrediction:
 
     def _reset_pipeline(self, model_name):
         if self._model_name != model_name:
+            print("Resetting pipeline with model", model_name)
             self._model_name = model_name
             self._pipeline = None
 
@@ -63,6 +67,13 @@ class ModelPrediction:
         matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
         return matches[-1].strip() if matches else pred
 
+    def make_predictions(self, prompts, model_name) -> list[dict]:
+        preds = []
+        for prompt in tqdm(prompts, desc=f"Analyzing Prompt with {model_name}"):
+            pred = self.make_prediction(prompt, model_name)
+            preds.append(pred)
+        return preds
+
     def make_prediction(self, prompt, model_name):
         if model_name not in self.model_name2pred_func:
             raise ValueError(
@@ -89,34 +100,25 @@ class ModelPrediction:
             model_name = "together_ai/Qwen/QwQ-32B"
         elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
             model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
+        elif "llama-8" in model_name:
+            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+            predict_fun = self.predict_with_hf
         else:
             raise ValueError("Model forbidden")
 
         return partial(predict_fun, model_name=model_name)
 
     def predict_with_api(self, prompt, model_name): # -> dict[str, Any | float]:
-        def track_cost_callback(
-            kwargs, # kwargs to completion
-            completion_response, # response from completion
-            start_time,
-            end_time, # start/end time
-        ):
-            try:
-                response_cost = kwargs[
-                    "response_cost"
-                ] # litellm calculates response cost for you
-                call_cost = response_cost
-            except:
-                pass
-
-        litellm.success_callback = [track_cost_callback]
-        call_cost = 0.0
         response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
-        return {"response": response, "cost": call_cost}
+        response_text = response["choices"][0]["message"]["content"]
+        return {
+            "response": response_text,
+            "cost": response._hidden_params["response_cost"],
+        }
 
     @spaces.GPU
     def predict_with_hf(self, prompt, model_name): # -> dict[str, Any | float]:
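
Note: the new "llama-8" entry is routed to predict_with_hf, but the body of that method sits outside the changed hunks. Below is a minimal sketch, assuming the text-generation pipeline imported at the top of prediction.py, an invented max_new_tokens budget, and a module-level cache standing in for the class's _pipeline attribute; none of this code is part of the commit.

# Sketch only: the real predict_with_hf is not shown in this diff; names such as
# max_new_tokens and the module-level cache are assumptions for illustration.
import torch
import spaces
from transformers import pipeline as hf_pipeline

_PIPELINE = None  # stands in for the cached self._pipeline managed by _reset_pipeline

@spaces.GPU
def predict_with_hf(prompt: str, model_name: str) -> dict:
    global _PIPELINE
    if _PIPELINE is None:
        _PIPELINE = hf_pipeline(
            "text-generation",
            model=model_name,            # e.g. "meta-llama/Meta-Llama-3-8B-Instruct"
            torch_dtype=torch.bfloat16,
            device_map="auto",           # relies on accelerate, added in requirements.txt
        )
    out = _PIPELINE([{"role": "user", "content": prompt}], max_new_tokens=512)
    # For chat-style input the pipeline returns the conversation with the
    # assistant turn appended as the last message.
    response_text = out[0]["generated_text"][-1]["content"]
    return {"response": response_text, "cost": 0.0}  # local inference, no API cost

If this reading is correct, the accelerate>=0.26.0 pin added below is what enables device_map="auto" for the pipeline on ZeroGPU hardware.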
requirements.txt CHANGED

@@ -10,9 +10,9 @@ eval-type-backport>=0.2.0
 openai==1.66.3
 litellm==1.63.14
 together==1.4.6
-litellm==1.63.14
 # Conditional dependency for Gradio (requires Python >=3.10)
 gradio>=5.20.1; python_version >= "3.10"
+accelerate>=0.26.0
 
 # Test dependencies
 streamlit>=1.43.0
test_prediction.py CHANGED

@@ -3,8 +3,10 @@ from prediction import ModelPrediction
 
 def main():
     model = ModelPrediction()
-    response = model.make_prediction("Hi, how are you?", "gpt-3.5")
-    print(response)
+    response = model.make_prediction("Hi, how are you?", "llama-8")
+    print(response) # dict[response, response_parsed, cost]
+
 
 if __name__ == "__main__":
     main()
+    # do something with prompt
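
For completeness, a hypothetical driver showing how the new make_predictions helper and the "llama-8" registry entry could be exercised together; the prompts below are invented for illustration and are not part of the commit.

# Hypothetical usage of the batch helper added in this commit; prompts are made up.
from prediction import ModelPrediction

model = ModelPrediction()
prompts = [
    "Return the SQL to list all customers from Italy",
    "Return the SQL to count orders placed in 2023",
]
# make_predictions wraps make_prediction in a tqdm loop, yielding one dict per prompt
results = model.make_predictions(prompts, "llama-8")
for result in results:
    print(result["response"], result["cost"])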