from functools import partial
from typing import Any
import os
import re
import time

# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Local fallback: a no-op stand-in for the HF `spaces` decorator so the
    # code also runs outside a ZeroGPU Space.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper


from transformers import pipeline as hf_pipeline
import litellm
from tqdm import tqdm


class ModelPrediction:
    def __init__(self):
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "o1-mini": self._init_model_prediction("o1-mini"),
            "QwQ": self._init_model_prediction("QwQ"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }
        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question into SQL code to be executed "
            "over the database to fetch the answer. "
            "Return the SQL code in ```sql ```\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    @property
    def pipeline(self):
        # Lazily build the HF text-generation pipeline for the current model.
        if self._pipeline is None:
            self._pipeline = hf_pipeline(
                task="text-generation",
                model=self._model_name,
                device_map="auto",
            )
        return self._pipeline

    def _reset_pipeline(self, model_name):
        # Drop the cached pipeline when switching to a different model.
        if self._model_name != model_name:
            self._model_name = model_name
            self._pipeline = None

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Extract everything between <answer> and </answer> tags;
        # fall back to the last ```sql ...``` block, then to the raw prediction.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
        return matches[-1].strip() if matches else pred

    def make_prediction(self, question, db_schema, model_name, prompt=None):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model not supported; supported models are "
                f"{list(self.model_name2pred_func.keys())}"
            )
        prompt = prompt or self.base_prompt
        prompt = prompt.format(question=question, db_schema=db_schema)
        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time
        return prediction

    def predict_with_api(self, prompt, model_name) -> dict[str, Any | float]:
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    @spaces.GPU
    def predict_with_hf(self, prompt, model_name) -> dict[str, Any | float]:
        self._reset_pipeline(model_name)
        response = self.pipeline([{"role": "user", "content": prompt}])[0][
            "generated_text"
        ][-1]["content"]
        return {"response": response, "cost": 0.0}

    def _init_model_prediction(self, model_name):
        # Map a short model alias to its provider-specific identifier and pick
        # the matching backend (litellm API call vs. local HF pipeline).
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "o1-mini" in model_name:
            model_name = "openai/o1-mini-2024-09-12"
        elif "QwQ" in model_name:
            model_name = "together_ai/Qwen/QwQ-32B"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
            predict_fun = self.predict_with_hf
        else:
            raise ValueError(f"Model not supported: {model_name}")

        return partial(predict_fun, model_name=model_name)