from functools import partial
import os
import re
import time
# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # No-op stand-in so @spaces.GPU also works outside ZeroGPU Spaces.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
from transformers import pipeline as hf_pipeline
import litellm

class ModelPrediction:
    def __init__(self):
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "o1-mini": self._init_model_prediction("o1-mini"),
            "QwQ": self._init_model_prediction("QwQ"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }
        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question into SQL code to be executed "
            "over the database to fetch the answer. "
            "Return the SQL code in ```sql ```\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    @property
    def pipeline(self):
        # Lazily build the HF pipeline for the currently selected model.
        if self._pipeline is None:
            self._pipeline = hf_pipeline(
                task="text-generation",
                model=self._model_name,
                device_map="auto",
            )
        return self._pipeline

    def _reset_pipeline(self, model_name):
        # Drop the cached pipeline when switching to a different model.
        if self._model_name != model_name:
            self._model_name = model_name
            self._pipeline = None

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Prefer the last <answer>...</answer> block; strip the ```sql fence
        # as a unit so that "sql" inside the query itself is preserved.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```sql", "").replace("```", "").strip()
        # Fall back to the last ```sql ...``` fence, else the raw prediction.
        matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
        return matches[-1].strip() if matches else pred
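
    # Example: _extract_answer_from_pred("<answer>```sql SELECT 1```</answer>")
    # -> "SELECT 1"; without <answer> tags the last ```sql ...``` fence wins,
    # and a prediction containing neither is returned unchanged.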

    def make_prediction(self, question, db_schema, model_name, prompt=None):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model '{model_name}' not supported; supported models are "
                f"{list(self.model_name2pred_func)}"
            )
        prompt = prompt or self.base_prompt
        prompt = prompt.format(question=question, db_schema=db_schema)
        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time
        return prediction

    def predict_with_api(self, prompt, model_name):
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        # _hidden_params is a private litellm field; litellm.completion_cost()
        # is the documented way to obtain the same value.
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    @spaces.GPU
    def predict_with_hf(self, prompt, model_name):
        self._reset_pipeline(model_name)
        # The chat pipeline returns the whole conversation; keep the content
        # of the last (assistant) message.
        response = self.pipeline([{"role": "user", "content": prompt}])[0][
            "generated_text"
        ][-1]["content"]
        return {"response": response, "cost": 0.0}

    def _init_model_prediction(self, model_name):
        # Map the short model alias to a provider-qualified model id and pick
        # the matching backend (LiteLLM API call vs. local HF pipeline).
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "o1-mini" in model_name:
            model_name = "openai/o1-mini-2024-09-12"
        elif "QwQ" in model_name:
            model_name = "together_ai/Qwen/QwQ-32B"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
            predict_fun = self.predict_with_hf
        else:
            raise ValueError(f"Model not supported: {model_name}")
        return partial(predict_fun, model_name=model_name)
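
# Minimal usage sketch, assuming the relevant provider key (e.g.
# OPENAI_API_KEY) is set in the environment; the schema and question below
# are toy examples.
if __name__ == "__main__":
    predictor = ModelPrediction()
    demo_schema = "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);"
    result = predictor.make_prediction(
        question="How many users are there?",
        db_schema=demo_schema,
        model_name="gpt-4o-mini",
    )
    print(result["response_parsed"])
    print(f"took {result['time']:.2f}s, cost ${result['cost']}")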