import os
import re
import time
from functools import partial

# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
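    # Minimal no-op stand-in so @spaces.GPU works outside ZeroGPU Spaces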

    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

import litellm
from transformers import pipeline as hf_pipeline


class ModelPrediction:
    def __init__(self):
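        # Map each supported model alias to a ready-to-call prediction function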
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "o1-mini": self._init_model_prediction("o1-mini"),
            "QwQ": self._init_model_prediction("QwQ"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }

        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question into SQL code to be executed over the database to fetch the answer. Return the SQL code in ```sql ```\n"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    @property
    def pipeline(self):
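        # Lazily build the Hugging Face text-generation pipeline for the currently selected model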
        if self._pipeline is None:
            self._pipeline = hf_pipeline(
                task="text-generation",
                model=self._model_name,
                device_map="auto",
            )
        return self._pipeline

    def _reset_pipeline(self, model_name):
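        # Drop the cached pipeline whenever a different model is requested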
        if self._model_name != model_name:
            self._model_name = model_name
            self._pipeline = None

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Prefer the last <answer>...</answer> block; otherwise fall back to ```sql ... ``` fences
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            # Strip code fences and a leading "sql" tag without touching "sql" inside the query
            answer = matches[-1].replace("```", "").strip()
            return re.sub(r"^sql\s*", "", answer, flags=re.IGNORECASE).strip()
        matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
        return matches[-1].strip() if matches else pred

    def make_prediction(self, question, db_schema, model_name, prompt=None):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model {model_name!r} is not supported; supported models are "
                f"{list(self.model_name2pred_func.keys())}"
            )

        prompt = prompt or self.base_prompt
        prompt = prompt.format(question=question, db_schema=db_schema)

        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time

        return prediction

    def predict_with_api(self, prompt, model_name) -> dict[str, str | float]:
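        # Route the prompt through litellm and return the raw response text plus the reported API cost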
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    @spaces.GPU
    def predict_with_hf(self, prompt, model_name) -> dict[str, str | float]:
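        # Run the prompt through a local Hugging Face pipeline (GPU-decorated for ZeroGPU Spaces)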
        self._reset_pipeline(model_name)
        response = self.pipeline([{"role": "user", "content": prompt}])[0][
            "generated_text"
        ][-1]["content"]
        return {"response": response, "cost": 0.0}

    def _init_model_prediction(self, model_name):
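        # Resolve the short model alias to its provider-specific id and pick the backend (API vs. local HF)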
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "o1-mini" in model_name:
            model_name = "openai/o1-mini-2024-09-12"
        elif "QwQ" in model_name:
            model_name = "together_ai/Qwen/QwQ-32B"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
            predict_fun = self.predict_with_hf
        else:
            raise ValueError(f"Unsupported model: {model_name}")

        return partial(predict_fun, model_name=model_name)
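

# Example usage: a minimal sketch, assuming OPENAI_API_KEY is set in the environment
# for the litellm-backed OpenAI models; the question and schema below are
# illustrative placeholders.
if __name__ == "__main__":
    predictor = ModelPrediction()
    result = predictor.make_prediction(
        question="How many users signed up in 2023?",
        db_schema="CREATE TABLE users (id INT, name TEXT, signup_date DATE);",
        model_name="gpt-4o-mini",
    )
    print(result["response_parsed"], result["cost"], result["time"])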