qatch-demo / prediction.py
simone-papicchio's picture
initial commit for prediction
c2e2aa2
raw
history blame
1.33 kB
# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
# Import the real ``spaces`` package only when the SPACES_ZERO_GPU environment
# variable is set; otherwise define a no-op stand-in so ``@spaces.GPU`` can
# still be applied to functions in this module.
import os

if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        """Local stand-in for the ``spaces`` package when not on ZeroGPU."""

        @staticmethod
        def GPU(func=None, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both the bare ``@spaces.GPU`` form and the parameterized
            ``@spaces.GPU(...)`` form; any keyword arguments are ignored.

            Returns:
                The decorated function itself (unchanged, so its name and
                docstring are preserved), or, when called without a function,
                an identity decorator.
            """
            if callable(func):
                return func

            # Parameterized use: ``@spaces.GPU(...)`` must yield a decorator.
            def decorator(inner):
                return inner

            return decorator
class ModelPrediction:
    """Select and invoke a prediction function for a supported model.

    The requested ``model_name`` is matched by substring against
    ``_MODEL_ALIASES``. The first match determines the canonical model id,
    which is stored in ``self.model_name``. Unsupported names raise
    ``ValueError('Model forbidden')`` at construction time.
    """

    # Ordered (substring, canonical model id) pairs. The first substring found
    # in the requested name wins, so order matters if entries ever overlap.
    _MODEL_ALIASES = (
        ('gpt-3.5', 'openai/gpt-3.5-turbo-0125'),
        ('gpt-4o-mini', 'openai/gpt-4o-mini-2024-07-18'),
        ('o1-mini', 'openai/o1-mini-2024-09-12'),
        ('QwQ', 'together_ai/Qwen/QwQ-32B'),
        ('DeepSeek-R1-Distill-Llama-70B',
         'together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B'),
    )

    def __init__(self, model_name):
        """Resolve *model_name* and bind the matching prediction function.

        Sets ``self.model_name`` (canonical id) and ``self.prediction_fun``.

        Raises:
            ValueError: if *model_name* matches no supported model.
        """
        self.prediction_fun = self._model_prediction(model_name)

    def make_prediction(self, prompt):
        """Run the bound prediction function on *prompt* and return its result."""
        return self.prediction_fun(prompt)

    @classmethod
    def _resolve_model_name(cls, model_name):
        """Map a user-facing model name to its canonical provider/model id.

        Raises:
            ValueError: if no known alias is a substring of *model_name*.
        """
        for alias, canonical in cls._MODEL_ALIASES:
            if alias in model_name:
                return canonical
        raise ValueError('Model forbidden')

    def _model_prediction(self, model_name):
        """Validate *model_name*, record its canonical id, return the predictor.

        Side effect: sets ``self.model_name`` to the canonical model id.

        Raises:
            ValueError: if *model_name* matches no supported model.
        """
        self.model_name = self._resolve_model_name(model_name)
        # Every supported model currently goes through ``predict_with_api``;
        # ``predict_with_hf`` is not selected by this method.
        return predict_with_api
def predict_with_api(prompt):
    """Placeholder for the API prediction path (not implemented yet).

    Currently ignores *prompt* and always returns ``None``.
    """
    return None
@spaces.GPU
def predict_with_hf(prompt):
    """Placeholder prediction function (not implemented yet).

    The ``hf`` suffix presumably refers to a Hugging Face model path; confirm
    before relying on it. Decorated with ``spaces.GPU``. Currently ignores
    *prompt* and always returns ``None``.
    """
    return None