simone-papicchio committed
Commit 40354df · 1 Parent(s): 45d1f9d

feat add llama8b zero spaces

Files changed (1): prediction.py (+13 −14)
prediction.py CHANGED

@@ -24,6 +24,14 @@ import litellm
 from tqdm import tqdm


+pipeline = transformers.pipeline(
+    "text-generation",
+    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+    model_kwargs={"torch_dtype": torch.bfloat16},
+)
+pipeline.to('cuda')
+
+
 class ModelPrediction:
     def __init__(self):
         self.model_name2pred_func = {
@@ -47,16 +55,6 @@ class ModelPrediction:
             "{db_schema}\n"
         )

-    @property
-    def pipeline(self):
-        if self._pipeline is None:
-            self._pipeline = hf_pipeline(
-                task="text-generation",
-                model=self._model_name,
-                device_map="auto",
-            )
-        return self._pipeline
-
     def _reset_pipeline(self, model_name):
         if self._model_name != model_name:
             self._model_name = model_name
@@ -110,10 +108,11 @@ class ModelPrediction:

     @spaces.GPU
     def predict_with_hf(self, prompt, model_name): # -> dict[str, Any | float]:
-        self._reset_pipeline(model_name)
-        response = self.pipeline([{"role": "user", "content": prompt}])[0][
-            "generated_text"
-        ][-1]["content"]
+        outputs = pipeline(
+            [{"role": "user", "content": prompt}],
+            max_new_tokens=256,
+        )
+        response = outputs[0]["generated_text"][-1]
         return {"response": response, "cost": 0.0}

     def _init_model_prediction(self, model_name):
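
The change replaces the per-instance, lazily built `pipeline` property with a single pipeline created at module import time, which is the pattern Hugging Face ZeroGPU Spaces expect: the model is loaded once at startup, and GPU work happens only inside functions decorated with `@spaces.GPU`. Below is a minimal sketch of that pattern, not the committed file; it assumes `spaces`, `transformers`, and `torch` are installed and the gated Llama checkpoint is accessible, and the standalone function name `generate` is illustrative only.

# Minimal sketch of the ZeroGPU-style setup this commit moves toward.
import spaces
import torch
import transformers

# Load the model once at import time; ZeroGPU attaches a GPU only while a
# @spaces.GPU-decorated function is running.
pipeline = transformers.pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)


@spaces.GPU
def generate(prompt: str) -> str:
    # Chat-style input: the pipeline returns the conversation with the
    # assistant reply appended as the last message.
    outputs = pipeline(
        [{"role": "user", "content": prompt}],
        max_new_tokens=256,
    )
    return outputs[0]["generated_text"][-1]["content"]

Keeping pipeline construction at module level moves model loading out of the request path, which matters on ZeroGPU where the GPU is only attached for the duration of the decorated call.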