sumesh4C committed on
Commit
258e407
·
verified ·
1 Parent(s): dbda781

Update tasks/utils/predict.py

Browse files
Files changed (1) hide show
  1. tasks/utils/predict.py +9 -4
tasks/utils/predict.py CHANGED
@@ -5,8 +5,10 @@ import pandas as pd
5
  import sys
6
  sys.path.append(".")
7
  from tasks.utils.preprocessing import process_text
 
 
8
 
9
- def predict(input_df: pd.DataFrame, tfidf_vectorizer , model_path: str):
10
  """
11
  Predict the output using a saved TF-IDF vectorizer and Random Forest model.
12
 
@@ -19,16 +21,19 @@ def predict(input_df: pd.DataFrame, tfidf_vectorizer , model_path: str):
19
  Returns:
20
  pd.Series: Predictions for each row in the input dataframe.
21
  """
22
- """
23
  # Load the TF-IDF vectorizer
24
  with open(tfidf_path, "rb") as tfidf_file:
25
- tfidf_vectorizer = pickle.load(tfidf_file)
26
- """
27
 
28
  # Load the Random Forest model
29
  with open(model_path, "rb") as model_file:
30
  model = pickle.load(model_file)
31
 
 
 
 
32
  # Transform the input text using the TF-IDF vectorizer
33
  text_data = input_df.to_pandas()["quote"]
34
  text_features = tfidf_vectorizer.transform(text_data)
 
5
  import sys
6
  sys.path.append(".")
7
  from tasks.utils.preprocessing import process_text
8
+ import json
9
+ from sklearn.feature_extraction.text import TfidfVectorizer
10
 
11
+ def predict(input_df: pd.DataFrame, tfidf_vectorizer:str , model_path: str):
12
  """
13
  Predict the output using a saved TF-IDF vectorizer and Random Forest model.
14
 
 
21
  Returns:
22
  pd.Series: Predictions for each row in the input dataframe.
23
  """
24
+
25
  # Load the TF-IDF vectorizer
26
  with open(tfidf_path, "rb") as tfidf_file:
27
+ params = json.load(tfidf_file)
28
+
29
 
30
  # Load the Random Forest model
31
  with open(model_path, "rb") as model_file:
32
  model = pickle.load(model_file)
33
 
34
+ tfidf_vectorizer = = TfidfVectorizer(**params)
35
+ tfidf_vectorizer.set_params(preprocessor=process_text)
36
+
37
  # Transform the input text using the TF-IDF vectorizer
38
  text_data = input_df.to_pandas()["quote"]
39
  text_features = tfidf_vectorizer.transform(text_data)