simone-papicchio commited on
Commit
220b4dd
·
1 Parent(s): d7c9e73

fix error in prediction.py

Browse files
Files changed (2) hide show
  1. prediction.py +1 -1
  2. test_prediction.py +10 -0
prediction.py CHANGED
@@ -23,7 +23,7 @@ import litellm
23
 
24
 
25
  class ModelPrediction:
26
- def __init__(self, model_name):
27
  self.model_name2pred_func = {
28
  "gpt-3.5": self._model_prediction("gpt-3.5"),
29
  "gpt-4o-mini": self._model_prediction("gpt-4o-mini"),
 
23
 
24
 
25
  class ModelPrediction:
26
+ def __init__(self):
27
  self.model_name2pred_func = {
28
  "gpt-3.5": self._model_prediction("gpt-3.5"),
29
  "gpt-4o-mini": self._model_prediction("gpt-4o-mini"),
test_prediction.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from prediction import ModelPrediction
2
+
3
+
4
+ def main():
5
+ model = ModelPrediction()
6
+ response = model.make_prediction("Hi, how are you?", "gpt-3.5")
7
+ print(response)
8
+
9
+ if __name__ == "__main__":
10
+ main()