Spaces:
Sleeping
Sleeping
Commit
·
c2e2aa2
1
Parent(s):
2ae6c7c
initial commit for prediction
Browse files- app.py +1 -15
- prediction.py +45 -0
- requirements.txt +3 -0
app.py
CHANGED
@@ -1,21 +1,11 @@
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
import os
|
4 |
-
|
5 |
-
# if os.environ.get("SPACES_ZERO_GPU") is not None:
|
6 |
-
# import spaces
|
7 |
-
# else:
|
8 |
-
# class spaces:
|
9 |
-
# @staticmethod
|
10 |
-
# def GPU(func):
|
11 |
-
# def wrapper(*args, **kwargs):
|
12 |
-
# return func(*args, **kwargs)
|
13 |
-
# return wrapper
|
14 |
import sys
|
15 |
from qatch.connectors.sqlite_connector import SqliteConnector
|
16 |
from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
|
17 |
from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
|
18 |
-
#from predictor.orchestrator_predictor import OrchestratorPredictor
|
19 |
import utils_get_db_tables_info
|
20 |
import utilities as us
|
21 |
import time
|
@@ -23,10 +13,6 @@ import plotly.express as px
|
|
23 |
import plotly.graph_objects as go
|
24 |
import plotly.colors as pc
|
25 |
|
26 |
-
# @spaces.GPU
|
27 |
-
# def model_prediction():
|
28 |
-
# pass
|
29 |
-
|
30 |
with open('style.css', 'r') as file:
|
31 |
css = file.read()
|
32 |
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
import os
|
4 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import sys
|
6 |
from qatch.connectors.sqlite_connector import SqliteConnector
|
7 |
from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
|
8 |
from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
|
|
|
9 |
import utils_get_db_tables_info
|
10 |
import utilities as us
|
11 |
import time
|
|
|
13 |
import plotly.graph_objects as go
|
14 |
import plotly.colors as pc
|
15 |
|
|
|
|
|
|
|
|
|
16 |
with open('style.css', 'r') as file:
|
17 |
css = file.read()
|
18 |
|
prediction.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
import functools
import os  # fixed: os.environ is read below, but this module never imported os

# On ZeroGPU Spaces hardware the real `spaces` package is available and its
# `@spaces.GPU` decorator schedules the call on a GPU worker. Everywhere else
# we install a no-op stand-in so `@spaces.GPU` can be applied unconditionally.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func):
            @functools.wraps(func)  # preserve the wrapped function's name/docstring
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
class ModelPrediction:
    """Resolve a user-facing model name to a prediction callable.

    Only an explicit allow-list of models is accepted; anything else raises
    ``ValueError('Model forbidden')``.
    """

    def __init__(self, model_name):
        # Callable used to run predictions for this model.
        # NOTE(review): was always None before — `_model_prediction` ended
        # in a bare `return`; it now returns the resolved callable.
        self.prediction_fun = self._model_prediction(model_name)

    def make_prediction(self, prompt):
        # Fixed: the original signature was `make_prediction(prompt)` with no
        # `self`, so calling it on an instance raised TypeError. Still a stub.
        pass

    def _model_prediction(self, model_name):
        """Map *model_name* to its provider-qualified id and return the
        function used to run it.

        Side effect: stores the resolved id on ``self.model_name`` (it was
        previously computed and then discarded).

        Raises:
            ValueError: if *model_name* is not in the allow-list.
        """
        predict_fun = predict_with_api
        if 'gpt-3.5' in model_name:
            model_name = 'openai/gpt-3.5-turbo-0125'
        elif 'gpt-4o-mini' in model_name:
            model_name = 'openai/gpt-4o-mini-2024-07-18'
        elif 'o1-mini' in model_name:
            model_name = 'openai/o1-mini-2024-09-12'
        elif 'QwQ' in model_name:
            model_name = 'together_ai/Qwen/QwQ-32B'
        elif 'DeepSeek-R1-Distill-Llama-70B' in model_name:
            model_name = 'together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B'
        else:
            raise ValueError('Model forbidden')

        self.model_name = model_name
        # Fixed: original had a bare `return`, silently dropping predict_fun.
        return predict_fun
def predict_with_api(prompt):
    """Run *prompt* against a remote inference API.

    Placeholder — not yet implemented; currently returns None.
    """
    return None
# Decorated so the call is dispatched to a GPU worker on ZeroGPU Spaces;
# with the local shim defined above, the decorator is a transparent no-op.
@spaces.GPU
def predict_with_hf(prompt):
    """Run *prompt* with a locally hosted Hugging Face model.

    Placeholder — not yet implemented.
    """
    pass
|
requirements.txt
CHANGED
@@ -7,6 +7,9 @@ langgraph>=0.2.53
|
|
7 |
langchain>=0.3.9
|
8 |
func-timeout>=4.3.5
|
9 |
eval-type-backport>=0.2.0
|
|
|
|
|
|
|
10 |
|
11 |
# Conditional dependency for Gradio (requires Python >=3.10)
|
12 |
gradio>=5.20.1; python_version >= "3.10"
|
|
|
7 |
langchain>=0.3.9
|
8 |
func-timeout>=4.3.5
|
9 |
eval-type-backport>=0.2.0
|
10 |
+
openai==1.66.3
|
11 |
+
litellm==1.63.14
|
12 |
+
together==1.4.6
|
13 |
|
14 |
# Conditional dependency for Gradio (requires Python >=3.10)
|
15 |
gradio>=5.20.1; python_version >= "3.10"
|