simone-papicchio committed
Commit d7c9e73 · Parent: 8c0ca3e

feat: add prediction with Together AI and HF pipeline

Files changed (2):
  1. prediction.py +105 -23
  2. requirements.txt +1 -1
prediction.py CHANGED
@@ -1,45 +1,127 @@
+from functools import partial
+import os
+import re
+
 # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
 if os.environ.get("SPACES_ZERO_GPU") is not None:
     import spaces
 else:
+
     class spaces:
         @staticmethod
         def GPU(func):
             def wrapper(*args, **kwargs):
                 return func(*args, **kwargs)
+
             return wrapper
 
 
+from transformers import pipeline as hf_pipeline
+import torch
+import litellm
+
 
 class ModelPrediction:
     def __init__(self, model_name):
-        self.prediction_fun = self._model_prediction(model_name)
+        self.model_name2pred_func = {
+            "gpt-3.5": self._model_prediction("gpt-3.5"),
+            "gpt-4o-mini": self._model_prediction("gpt-4o-mini"),
+            "o1-mini": self._model_prediction("o1-mini"),
+            "QwQ": self._model_prediction("QwQ"),
+            "DeepSeek-R1-Distill-Llama-70B": self._model_prediction(
+                "DeepSeek-R1-Distill-Llama-70B"
+            ),
+        }
+
+        self._model_name = None
+        self._pipeline = None
+
+    @property
+    def pipeline(self):
+        # Lazily build (and cache) the HF text-generation pipeline for the
+        # currently selected model.
+        if self._pipeline is None:
+            self._pipeline = hf_pipeline(
+                task="text-generation",
+                model=self._model_name,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+            )
+        return self._pipeline
+
+    def _reset_pipeline(self, model_name):
+        # Invalidate the cached pipeline when a different model is requested.
+        if self._model_name != model_name:
+            self._model_name = model_name
+            self._pipeline = None
+
+    @staticmethod
+    def _extract_answer_from_pred(pred: str) -> str:
+        # Extract with regex everything between the last <answer> and </answer>.
+        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
+        if matches:
+            return matches[-1].replace("```", "").replace("sql", "").strip()
+        else:
+            # Fall back to the last ```sql ... ``` fenced block, if any.
+            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
+            return matches[-1].strip() if matches else pred
+
+    def make_prediction(self, prompt, model_name):
+        if model_name not in self.model_name2pred_func:
+            raise ValueError(
+                "Model not supported; supported models are "
+                f"{list(self.model_name2pred_func)}"
+            )
+
+        prediction = self.model_name2pred_func[model_name](prompt)
+        prediction["response_parsed"] = self._extract_answer_from_pred(
+            prediction["response"]
+        )
+        return prediction
 
-    def make_prediction(prompt):
-        pass
-
     def _model_prediction(self, model_name):
-        predict_fun = predict_with_api
-        if 'gpt-3.5' in model_name:
-            model_name = 'openai/gpt-3.5-turbo-0125'
-        elif 'gpt-4o-mini' in model_name:
-            model_name = 'openai/gpt-4o-mini-2024-07-18'
-        elif 'o1-mini' in model_name:
-            model_name = 'openai/o1-mini-2024-09-12'
-        elif 'QwQ' in model_name:
-            model_name = 'together_ai/Qwen/QwQ-32B'
-        elif 'DeepSeek-R1-Distill-Llama-70B' in model_name:
-            model_name = 'together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B'
+        predict_fun = self.predict_with_api
+        if "gpt-3.5" in model_name:
+            model_name = "openai/gpt-3.5-turbo-0125"
+        elif "gpt-4o-mini" in model_name:
+            model_name = "openai/gpt-4o-mini-2024-07-18"
+        elif "o1-mini" in model_name:
+            model_name = "openai/o1-mini-2024-09-12"
+        elif "QwQ" in model_name:
+            model_name = "together_ai/Qwen/QwQ-32B"
+        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
+            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
         else:
-            raise ValueError('Model forbidden')
+            raise ValueError("Model forbidden")
 
-        return
-
-    def predict_with_api(prompt):
-        pass
+        return partial(predict_fun, model_name=model_name)
 
-    @spaces.GPU
-    def predict_with_hf(prompt):
-        pass
+    def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]
+        call_cost = 0.0
+
+        def track_cost_callback(
+            kwargs,  # kwargs passed to completion
+            completion_response,  # response from completion
+            start_time,
+            end_time,  # start/end time
+        ):
+            # litellm calculates the response cost for you; propagate it to
+            # the enclosing scope instead of shadowing a local variable.
+            nonlocal call_cost
+            try:
+                call_cost = kwargs["response_cost"]
+            except KeyError:
+                pass
+
+        litellm.success_callback = [track_cost_callback]
+        response = litellm.completion(
+            model=model_name,
+            messages=[{"role": "user", "content": prompt}],
+            num_retries=2,
+        )
+        # Return the message text, not the raw response object, so that
+        # _extract_answer_from_pred can run its regexes on a string.
+        text = response["choices"][0]["message"]["content"]
+        return {"response": text, "cost": call_cost}
+
+    @spaces.GPU
+    def predict_with_hf(self, prompt, model_name):  # -> dict[str, Any | float]
+        self._reset_pipeline(model_name)
+        # The chat pipeline returns the full message list; take the content
+        # of the last (assistant) message.
+        response = self.pipeline([{"role": "user", "content": prompt}])[0][
+            "generated_text"
+        ][-1]["content"]
+        return {"response": response, "cost": 0.0}
requirements.txt CHANGED
@@ -10,7 +10,7 @@ eval-type-backport>=0.2.0
 openai==1.66.3
 litellm==1.63.14
 together==1.4.6
-
+litellm==1.63.14
 # Conditional dependency for Gradio (requires Python >=3.10)
 gradio>=5.20.1; python_version >= "3.10"
 
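
One note against the prediction.py hunk above: the added line duplicates the litellm==1.63.14 pin already present on the previous line, while the modules the new code imports (transformers, torch) are not declared here. A hedged sketch of the dependency lines the HF pipeline path would need, with versions left unpinned since the commit does not specify them:

# hypothetical additions for the HF pipeline path; versions are assumptions
transformers
torch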