# qatch-demo / prediction.py
from functools import partial
import os
import re

# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Fallback: a no-op GPU decorator when running outside a ZeroGPU Space.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

from transformers import pipeline as hf_pipeline
import torch
import litellm
from tqdm import tqdm


class ModelPrediction:
    """Route prompts to a LiteLLM API call or a local Hugging Face pipeline, depending on the model."""

    def __init__(self):
        self.model_name2pred_func = {
            "gpt-3.5": self._model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._model_prediction("gpt-4o-mini"),
            "o1-mini": self._model_prediction("o1-mini"),
            "QwQ": self._model_prediction("QwQ"),
            "DeepSeek-R1-Distill-Llama-70B": self._model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._model_prediction("llama-8"),
        }
        self._model_name = None
        self._pipeline = None

    @property
    def pipeline(self):
        if self._pipeline is None:
            self._pipeline = hf_pipeline(
                task="text-generation",
                model=self._model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        return self._pipeline

    def _reset_pipeline(self, model_name):
        if self._model_name != model_name:
            print("Resetting pipeline with model", model_name)
            self._model_name = model_name
            self._pipeline = None

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Extract everything between <answer> and </answer> tags from the prediction.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        else:
            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
            return matches[-1].strip() if matches else pred
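    # Example: given a prediction like "reasoning ... <answer>SELECT COUNT(*) FROM users</answer>",
    # the helper above returns "SELECT COUNT(*) FROM users"; without <answer> tags it falls back to
    # the last ```sql ...``` fenced block, and otherwise returns the prediction unchanged.
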
    def make_predictions(self, prompts, model_name) -> list[dict]:
        preds = []
        for prompt in tqdm(prompts, desc=f"Analyzing Prompt with {model_name}"):
            pred = self.make_prediction(prompt, model_name)
            preds.append(pred)
        return preds

    def make_prediction(self, prompt, model_name):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model not supported; supported models are: "
                f"{list(self.model_name2pred_func.keys())}"
            )
        prediction = self.model_name2pred_func[model_name](prompt)
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        return prediction

    def _model_prediction(self, model_name):
        # Resolve the short model alias to a provider-specific identifier; only the
        # llama-8 alias runs locally through the Hugging Face pipeline, the rest go
        # through the LiteLLM API.
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "o1-mini" in model_name:
            model_name = "openai/o1-mini-2024-09-12"
        elif "QwQ" in model_name:
            model_name = "together_ai/Qwen/QwQ-32B"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
            predict_fun = self.predict_with_hf
        else:
            raise ValueError("Model forbidden")
        return partial(predict_fun, model_name=model_name)

    def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]:
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    @spaces.GPU
    def predict_with_hf(self, prompt, model_name):  # -> dict[str, Any | float]:
        self._reset_pipeline(model_name)
        response = self.pipeline([{"role": "user", "content": prompt}])[0][
            "generated_text"
        ][-1]["content"]
        return {"response": response, "cost": 0.0}