File size: 5,698 Bytes
b4f6ffc
 
 
 
 
c3c7748
b4f6ffc
c3c7748
 
b4f6ffc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3c7748
b4f6ffc
c3c7748
 
b4f6ffc
c3c7748
b4f6ffc
 
 
 
c3c7748
b4f6ffc
c3c7748
 
 
 
 
 
 
 
 
 
 
 
 
b4f6ffc
c3c7748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4f6ffc
c3c7748
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import sys
from anime_recommender.loggers.logging import logging
from anime_recommender.exception.exception import AnimeRecommendorException
from anime_recommender.entity.config_entity import CollaborativeModelConfig
from anime_recommender.entity.artifact_entity import DataTransformationArtifact, CollaborativeModelArtifact
from anime_recommender.utils.main_utils.utils import load_csv_data, save_model, load_object, upload_model_to_huggingface
from anime_recommender.model_trainer.collaborative_modelling import CollaborativeAnimeRecommender
from anime_recommender.constant import *
 
class CollaborativeModelTrainer:
    """
    Trains and saves collaborative filtering recommendation models.

    This class supports three types of models:
    - Singular Value Decomposition (SVD)
    - Item-based K-Nearest Neighbors (KNN)
    - User-based K-Nearest Neighbors (KNN)

    For each model the trainer fits it, saves it to the path given by the
    trainer configuration, hands the saved file to
    ``upload_model_to_huggingface`` (with ``repo_id=MODELS_FILEPATH``),
    reloads it from disk and logs one sample recommendation query.
    """

    def __init__(self, collaborative_model_trainer_config: CollaborativeModelConfig, data_transformation_artifact: DataTransformationArtifact):
        """
        Initializes the CollaborativeModelTrainer with configuration and transformed data.

        Args:
            collaborative_model_trainer_config (CollaborativeModelConfig): Configuration settings for model training.
            data_transformation_artifact (DataTransformationArtifact): Data artifact containing the preprocessed dataset path.
        """
        # Plain attribute assignments cannot raise, so no exception wrapping
        # is needed around them.
        self.collaborative_model_trainer_config = collaborative_model_trainer_config
        self.data_transformation_artifact = data_transformation_artifact

    def _save_upload_and_reload(self, model, file_path, hub_filename: str, label: str):
        """
        Saves a trained model, uploads the saved file and loads it back from disk.

        Args:
            model: The trained model object to save.
            file_path: Local path the model is written to and read back from.
            hub_filename (str): File name passed to ``upload_model_to_huggingface``.
            label (str): Model name inserted into the "Loading pre-trained ..." log line.

        Returns:
            The object returned by ``load_object`` for ``file_path``.
        """
        save_model(model=model, file_path=file_path)
        upload_model_to_huggingface(
            model_path=file_path,
            repo_id=MODELS_FILEPATH,
            filename=hub_filename,
        )
        # Reloading from disk means the sample query below exercises the
        # saved file rather than the in-memory model.
        logging.info(f"Loading pre-trained {label} model...")
        return load_object(file_path)

    def initiate_model_trainer(self) -> CollaborativeModelArtifact:
        """
        Trains and saves all collaborative filtering models.

        Returns:
            CollaborativeModelArtifact: Object containing file paths of all trained models.

        Raises:
            AnimeRecommendorException: If any step (loading data, training,
                saving, uploading, reloading or the sample query) fails. The
                original exception is attached as ``__cause__``.
        """
        try:
            config = self.collaborative_model_trainer_config

            logging.info("Loading transformed data...")
            df = load_csv_data(self.data_transformation_artifact.merged_file_path)
            recommender = CollaborativeAnimeRecommender(df)

            # Train and save SVD model
            logging.info("Training and saving SVD model...")
            recommender.train_svd()
            svd_model = self._save_upload_and_reload(
                model=recommender.svd,
                file_path=config.svd_trained_model_file_path,
                hub_filename=MODEL_TRAINER_SVD_TRAINED_MODEL_NAME,
                label="SVD",
            )
            # NOTE(review): user_id=436 is a hard-coded sample user; presumably
            # it exists in the training data -- confirm against the dataset.
            svd_recommendations = recommender.get_svd_recommendations(user_id=436, n=10, svd_model=svd_model)
            logging.info(f"SVD recommendations: {svd_recommendations}")

            # Train and save Item-Based KNN model
            logging.info("Training and saving KNN item-based model...")
            recommender.train_knn_item_based()
            item_knn_model = self._save_upload_and_reload(
                model=recommender.knn_item_based,
                file_path=config.item_knn_trained_model_file_path,
                hub_filename=MODEL_TRAINER_ITEM_KNN_TRAINED_MODEL_NAME,
                label="item-based KNN",
            )
            # NOTE(review): 'One Piece' is a hard-coded sample title; presumably
            # it exists in the training data -- confirm against the dataset.
            item_based_recommendations = recommender.get_item_based_recommendations(
                anime_name='One Piece', n_recommendations=10, knn_item_model=item_knn_model
            )
            logging.info(f"Item Based recommendations: {item_based_recommendations}")

            # Train and save User-Based KNN model
            logging.info("Training and saving KNN user-based model...")
            recommender.train_knn_user_based()
            user_knn_model = self._save_upload_and_reload(
                model=recommender.knn_user_based,
                file_path=config.user_knn_trained_model_file_path,
                hub_filename=MODEL_TRAINER_USER_KNN_TRAINED_MODEL_NAME,
                label="user-based KNN",
            )
            # NOTE(review): user_id=817 is a hard-coded sample user; presumably
            # it exists in the training data -- confirm against the dataset.
            user_based_recommendations = recommender.get_user_based_recommendations(
                user_id=817, n_recommendations=10, knn_user_model=user_knn_model
            )
            logging.info(f"User Based recommendations: {user_based_recommendations}")

            return CollaborativeModelArtifact(
                svd_file_path=config.svd_trained_model_file_path,
                item_based_knn_file_path=config.item_knn_trained_model_file_path,
                user_based_knn_file_path=config.user_knn_trained_model_file_path
            )
        except Exception as e:
            # Chain explicitly so the traceback shows the original failure as
            # the direct cause of the wrapped exception.
            raise AnimeRecommendorException(f"Error in CollaborativeModelTrainer: {str(e)}", sys) from e