krishnaveni76 committed on
Commit
6e95f91
·
1 Parent(s): 293022c

Collaborative filtering completed

Browse files
anime_recommender/entity/artifact_entity.py CHANGED
@@ -10,9 +10,12 @@ class DataIngestionArtifact:
10
  class DataTransformationArtifact:
11
  merged_file_path:str
12
 
13
-
14
-
15
-
 
 
 
16
  @dataclass
17
  class ContentBasedModelArtifact:
18
  cosine_similarity_model_file_path:str
 
10
  class DataTransformationArtifact:
11
  merged_file_path:str
12
 
13
+ @dataclass
14
+ class CollaborativeModelArtifact:
15
+ svd_file_path: Optional[str] = None
16
+ item_based_knn_file_path: Optional[str] = None
17
+ user_based_knn_file_path: Optional[str] = None
18
+
19
  @dataclass
20
  class ContentBasedModelArtifact:
21
  cosine_similarity_model_file_path:str
anime_recommender/entity/config_entity.py CHANGED
@@ -41,6 +41,18 @@ class DataTransformationConfig:
41
  self.data_transformation_dir:str = os.path.join(training_pipeline_config.artifact_dir,DATA_TRANSFORMATION_DIR)
42
  self.merged_file_path:str = os.path.join(self.data_transformation_dir,DATA_TRANSFORMATION_TRANSFORMED_DATA_DIR,MERGED_FILE_NAME)
43
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  class ContentBasedModelConfig:
46
  """
 
41
  self.data_transformation_dir:str = os.path.join(training_pipeline_config.artifact_dir,DATA_TRANSFORMATION_DIR)
42
  self.merged_file_path:str = os.path.join(self.data_transformation_dir,DATA_TRANSFORMATION_TRANSFORMED_DATA_DIR,MERGED_FILE_NAME)
43
 
44
+ class CollaborativeModelConfig:
45
+ """
46
+ Configuration for model training, including paths for trained models.
47
+ """
48
+ def __init__(self,training_pipeline_config:TrainingPipelineConfig):
49
+ """
50
+ Initialize model trainer paths.
51
+ """
52
+ self.model_trainer_dir:str = os.path.join(training_pipeline_config.artifact_dir,MODEL_TRAINER_DIR_NAME)
53
+ self.svd_trained_model_file_path:str = os.path.join(self.model_trainer_dir,MODEL_TRAINER_COL_TRAINED_MODEL_DIR,MODEL_TRAINER_SVD_TRAINED_MODEL_NAME)
54
+ self.user_knn_trained_model_file_path:str = os.path.join(self.model_trainer_dir,MODEL_TRAINER_COL_TRAINED_MODEL_DIR,MODEL_TRAINER_USER_KNN_TRAINED_MODEL_NAME)
55
+ self.item_knn_trained_model_file_path:str = os.path.join(self.model_trainer_dir,MODEL_TRAINER_COL_TRAINED_MODEL_DIR,MODEL_TRAINER_ITEM_KNN_TRAINED_MODEL_NAME)
56
 
57
  class ContentBasedModelConfig:
58
  """
anime_recommender/source/collaborative_recommenders.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from anime_recommender.loggers.logging import logging
3
+ from anime_recommender.exception.exception import AnimeRecommendorException
4
+ from anime_recommender.entity.config_entity import CollaborativeModelConfig
5
+ from anime_recommender.entity.artifact_entity import DataTransformationArtifact, CollaborativeModelArtifact
6
+ from anime_recommender.utils.main_utils.utils import load_csv_data, save_model, load_object
7
+ from anime_recommender.model_trainer.collaborative_filtering import CollaborativeAnimeRecommender
8
+
9
+ class CollaborativeModelTrainer:
10
+ """
11
+ Class to train the model, track metrics, and save the trained model.
12
+ """
13
+ def __init__(self, collaborative_model_trainer_config: CollaborativeModelConfig, data_transformation_artifact: DataTransformationArtifact):
14
+ try:
15
+ self.collaborative_model_trainer_config = collaborative_model_trainer_config
16
+ self.data_transformation_artifact = data_transformation_artifact
17
+ except Exception as e:
18
+ raise AnimeRecommendorException(e, sys)
19
+
20
+ def initiate_model_trainer(self, model_type: str) -> CollaborativeModelArtifact:
21
+ try:
22
+ logging.info("Loading transformed data...")
23
+ df = load_csv_data(self.data_transformation_artifact.merged_file_path)
24
+ recommender = CollaborativeAnimeRecommender(df)
25
+ # recommender.print_unique_user_ids()
26
+ if model_type == 'svd':
27
+ logging.info("Training and saving SVD model...")
28
+ recommender.train_svd()
29
+ save_model(recommender.svd, self.collaborative_model_trainer_config.svd_trained_model_file_path)
30
+
31
+ logging.info("Loading pre-trained SVD model...")
32
+ svd_model = load_object(self.collaborative_model_trainer_config.svd_trained_model_file_path)
33
+ svd_recommendations = recommender.get_svd_recommendations(user_id=436, n=10, svd_model=svd_model)
34
+ logging.info(f"SVD recommendations: {svd_recommendations}")
35
+ return CollaborativeModelArtifact(
36
+ svd_file_path=self.collaborative_model_trainer_config.svd_trained_model_file_path
37
+ )
38
+
39
+ elif model_type == 'item_knn':
40
+ logging.info("Training and saving KNN item-based model...")
41
+ recommender.train_knn_item_based()
42
+ save_model(recommender.knn_item_based, self.collaborative_model_trainer_config.item_knn_trained_model_file_path)
43
+
44
+ logging.info("Loading pre-trained item-based KNN model...")
45
+ item_knn_model = load_object(self.collaborative_model_trainer_config.item_knn_trained_model_file_path)
46
+ item_based_recommendations = recommender.get_item_based_recommendations(
47
+ anime_name='One Piece', n_recommendations=10, knn_item_model=item_knn_model
48
+ )
49
+ logging.info(f"Item Based recommendations: {item_based_recommendations}")
50
+ return CollaborativeModelArtifact(
51
+ item_based_knn_file_path=self.collaborative_model_trainer_config.item_knn_trained_model_file_path
52
+ )
53
+
54
+ elif model_type == 'user_knn':
55
+ logging.info("Training and saving KNN user-based model...")
56
+ recommender.train_knn_user_based()
57
+ save_model(recommender.knn_user_based, self.collaborative_model_trainer_config.user_knn_trained_model_file_path)
58
+
59
+ logging.info("Loading pre-trained user-based KNN model...")
60
+ user_knn_model = load_object(self.collaborative_model_trainer_config.user_knn_trained_model_file_path)
61
+ user_based_recommendations = recommender.get_user_based_recommendations(
62
+ user_id=817, n_recommendations=10, knn_user_model=user_knn_model
63
+ )
64
+ logging.info(f"User Based recommendations: {user_based_recommendations}")
65
+ return CollaborativeModelArtifact(
66
+ user_based_knn_file_path=self.collaborative_model_trainer_config.user_knn_trained_model_file_path
67
+ )
68
+
69
+ else:
70
+ raise ValueError("Invalid model_type. Choose from 'svd', 'item_knn', or 'user_knn'.")
71
+
72
+ except Exception as e:
73
+ raise AnimeRecommendorException(f"Error in CollaborativeModelTrainer: {str(e)}", sys)
app.py CHANGED
@@ -40,10 +40,6 @@ item_based_knn_model_path = hf_hub_download(repo_name, "itembasedknn.pkl")
40
  user_based_knn_model_path = hf_hub_download(repo_name, "userbasedknn.pkl")
41
  svd_model_path = hf_hub_download(repo_name, "svd.pkl")
42
 
43
- # # Load the models into memory
44
- # with open(cosine_similarity_model_path, "rb") as f:
45
- # cosine_similarity_model = joblib.load(f)
46
-
47
  with open(item_based_knn_model_path, "rb") as f:
48
  item_based_knn_model = joblib.load(f)
49
 
 
40
  user_based_knn_model_path = hf_hub_download(repo_name, "userbasedknn.pkl")
41
  svd_model_path = hf_hub_download(repo_name, "svd.pkl")
42
 
 
 
 
 
43
  with open(item_based_knn_model_path, "rb") as f:
44
  item_based_knn_model = joblib.load(f)
45
 
run_pipeline.py CHANGED
@@ -2,10 +2,10 @@ import sys
2
  from anime_recommender.loggers.logging import logging
3
  from anime_recommender.exception.exception import AnimeRecommendorException
4
  from anime_recommender.source.data_ingestion import DataIngestion
5
- from anime_recommender.entity.config_entity import TrainingPipelineConfig,DataIngestionConfig,DataTransformationConfig,ContentBasedModelConfig
6
- # ,DataTransformationConfig,CollaborativeModelConfig
7
  from anime_recommender.source.data_transformation import DataTransformation
8
- # from anime_recommender.source.collaborative_recommenders import CollaborativeModelTrainer
9
  from anime_recommender.source.content_based_recommender import ContentBasedModelTrainer
10
  # from anime_recommender.source.popularity_based_recommenders import PopularityBasedRecommendor
11
 
@@ -27,21 +27,21 @@ if __name__ == "__main__":
27
  logging.info("Data Transformation Completed.")
28
  print(data_transformation_artifact)
29
 
30
- # # Collaborative Model Training
31
- # collaborative_model_trainer_config = CollaborativeModelConfig(training_pipeline_config)
32
- # collaborative_model_trainer = CollaborativeModelTrainer(collaborative_model_trainer_config= collaborative_model_trainer_config,data_transformation_artifact=data_transformation_artifact)
33
- # logging.info("Initiating Collaborative Model training.")
34
- # collaborative_model_trainer_artifact = collaborative_model_trainer.initiate_model_trainer(model_type='svd')
35
- # logging.info("Collaborative Model training completed.")
36
- # print(collaborative_model_trainer_artifact)
37
 
38
  # Content Based Model Training
39
- content_based_model_trainer_config = ContentBasedModelConfig(training_pipeline_config)
40
- content_based_model_trainer = ContentBasedModelTrainer(content_based_model_trainer_config=content_based_model_trainer_config,data_ingestion_artifact=data_ingestion_artifact)
41
- logging.info("Initiating Content Based Model training.")
42
- content_based_model_trainer_artifact = content_based_model_trainer.initiate_model_trainer()
43
- logging.info("Content Based Model training completed.")
44
- print(content_based_model_trainer_artifact)
45
 
46
  # # Popularity Based Filtering
47
  # logging.info("Initiating Popularity based filtering.")
 
2
  from anime_recommender.loggers.logging import logging
3
  from anime_recommender.exception.exception import AnimeRecommendorException
4
  from anime_recommender.source.data_ingestion import DataIngestion
5
+ from anime_recommender.entity.config_entity import TrainingPipelineConfig,DataIngestionConfig,DataTransformationConfig,CollaborativeModelConfig,ContentBasedModelConfig
6
+ # ,DataTransformationConfig
7
  from anime_recommender.source.data_transformation import DataTransformation
8
+ from anime_recommender.source.collaborative_recommenders import CollaborativeModelTrainer
9
  from anime_recommender.source.content_based_recommender import ContentBasedModelTrainer
10
  # from anime_recommender.source.popularity_based_recommenders import PopularityBasedRecommendor
11
 
 
27
  logging.info("Data Transformation Completed.")
28
  print(data_transformation_artifact)
29
 
30
+ # Collaborative Model Training
31
+ collaborative_model_trainer_config = CollaborativeModelConfig(training_pipeline_config)
32
+ collaborative_model_trainer = CollaborativeModelTrainer(collaborative_model_trainer_config= collaborative_model_trainer_config,data_transformation_artifact=data_transformation_artifact)
33
+ logging.info("Initiating Collaborative Model training.")
34
+ collaborative_model_trainer_artifact = collaborative_model_trainer.initiate_model_trainer(model_type='user_knn')
35
+ logging.info("Collaborative Model training completed.")
36
+ print(collaborative_model_trainer_artifact)
37
 
38
  # Content Based Model Training
39
+ # content_based_model_trainer_config = ContentBasedModelConfig(training_pipeline_config)
40
+ # content_based_model_trainer = ContentBasedModelTrainer(content_based_model_trainer_config=content_based_model_trainer_config,data_ingestion_artifact=data_ingestion_artifact)
41
+ # logging.info("Initiating Content Based Model training.")
42
+ # content_based_model_trainer_artifact = content_based_model_trainer.initiate_model_trainer()
43
+ # logging.info("Content Based Model training completed.")
44
+ # print(content_based_model_trainer_artifact)
45
 
46
  # # Popularity Based Filtering
47
  # logging.info("Initiating Popularity based filtering.")