sumesh4C commited on
Commit
601d216
·
verified ·
1 Parent(s): c9878b6

Update tasks/utils/predict.py

Browse files
Files changed (1) hide show
  1. tasks/utils/predict.py +6 -1
tasks/utils/predict.py CHANGED
@@ -8,7 +8,7 @@ from tasks.utils.preprocessing import process_text
8
  import json
9
  from sklearn.feature_extraction.text import TfidfVectorizer
10
 
11
- def predict(input_df: pd.DataFrame, tfidf_path:str , model_path: str):
12
  """
13
  Predict the output using a saved TF-IDF vectorizer and Random Forest model.
14
 
@@ -31,8 +31,13 @@ def predict(input_df: pd.DataFrame, tfidf_path:str , model_path: str):
31
  with open(model_path, "rb") as model_file:
32
  model = pickle.load(model_file)
33
 
 
 
 
 
34
  tfidf_vectorizer = TfidfVectorizer(**params)
35
  tfidf_vectorizer.set_params(preprocessor=process_text)
 
36
 
37
  # Transform the input text using the TF-IDF vectorizer
38
  text_data = input_df.to_pandas()["quote"]
 
8
  import json
9
  from sklearn.feature_extraction.text import TfidfVectorizer
10
 
11
+ def predict(input_df: pd.DataFrame, tfidf_path:str , tfidf_voc_path:str, model_path: str):
12
  """
13
  Predict the output using a saved TF-IDF vectorizer and Random Forest model.
14
 
 
31
  with open(model_path, "rb") as model_file:
32
  model = pickle.load(model_file)
33
 
34
+ # Load vocabulary
35
+ with open(tfidf_voc_path, "rb") as f:
36
+ vocab = pickle.load(f)
37
+
38
  tfidf_vectorizer = TfidfVectorizer(**params)
39
  tfidf_vectorizer.set_params(preprocessor=process_text)
40
+ tfidf_vectorizer.set_params(vocabulary=vocab)
41
 
42
  # Transform the input text using the TF-IDF vectorizer
43
  text_data = input_df.to_pandas()["quote"]