|
import sys
|
|
from anime_recommender.loggers.logging import logging
|
|
from anime_recommender.exception.exception import AnimeRecommendorException
|
|
from anime_recommender.entity.config_entity import CollaborativeModelConfig
|
|
from anime_recommender.entity.artifact_entity import DataTransformationArtifact, CollaborativeModelArtifact
|
|
from anime_recommender.utils.main_utils.utils import load_csv_data, save_model, load_object, upload_model_to_huggingface
|
|
from anime_recommender.model_trainer.collaborative_modelling import CollaborativeAnimeRecommender
|
|
from anime_recommender.constant import *
|
|
|
|
class CollaborativeModelTrainer:
    """
    Trains and saves collaborative filtering recommendation models.

    This class supports three types of models:
    - Singular Value Decomposition (SVD)
    - Item-based K-Nearest Neighbors (KNN)
    - User-based K-Nearest Neighbors (KNN)

    Each model is trained through ``CollaborativeAnimeRecommender``, written to
    the path given by the trainer config, handed to
    ``upload_model_to_huggingface``, loaded back from that path and used for
    one sample recommendation call whose result is logged.
    """

    # Inputs for the sample recommendation call made after each model is
    # reloaded from disk; the results are only logged.
    _SVD_SAMPLE_USER_ID = 436
    _ITEM_KNN_SAMPLE_ANIME = 'One Piece'
    _USER_KNN_SAMPLE_USER_ID = 817
    _SAMPLE_RECOMMENDATION_COUNT = 10

    def __init__(self, collaborative_model_trainer_config: CollaborativeModelConfig, data_transformation_artifact: DataTransformationArtifact):
        """
        Initializes the CollaborativeModelTrainer with configuration and transformed data.

        Args:
            collaborative_model_trainer_config (CollaborativeModelConfig): Configuration settings for model training.
            data_transformation_artifact (DataTransformationArtifact): Data artifact containing the preprocessed dataset path.
        """
        # Plain attribute assignments cannot raise, so no exception wrapping
        # is needed around them.
        self.collaborative_model_trainer_config = collaborative_model_trainer_config
        self.data_transformation_artifact = data_transformation_artifact

    @staticmethod
    def _save_upload_reload(model, file_path, hub_filename, label):
        """
        Saves a trained model, uploads the saved file and loads it back.

        Args:
            model: Trained model object to persist.
            file_path: Path the model is saved to, uploaded from and reloaded from.
            hub_filename: File name passed to ``upload_model_to_huggingface``.
            label: Model description inserted into the "Loading pre-trained ..." log line.

        Returns:
            The object returned by ``load_object`` for ``file_path``.
        """
        save_model(model=model, file_path=file_path)
        upload_model_to_huggingface(
            model_path=file_path,
            repo_id=MODELS_FILEPATH,
            filename=hub_filename
        )
        logging.info(f"Loading pre-trained {label} model...")
        return load_object(file_path)

    def initiate_model_trainer(self) -> CollaborativeModelArtifact:
        """
        Trains and saves all collaborative filtering models.

        The SVD, item-based KNN and user-based KNN models are processed in that
        order. A failure at any step aborts the remaining steps.

        Returns:
            CollaborativeModelArtifact: Object containing file paths of all trained models.

        Raises:
            AnimeRecommendorException: Wraps any exception raised while loading
                data, training, saving, uploading, reloading or generating the
                sample recommendations. The original exception is chained as
                the cause.
        """
        try:
            config = self.collaborative_model_trainer_config

            logging.info("Loading transformed data...")
            df = load_csv_data(self.data_transformation_artifact.merged_file_path)
            recommender = CollaborativeAnimeRecommender(df)

            # --- SVD ---
            logging.info("Training and saving SVD model...")
            recommender.train_svd()
            svd_model = self._save_upload_reload(
                model=recommender.svd,
                file_path=config.svd_trained_model_file_path,
                hub_filename=MODEL_TRAINER_SVD_TRAINED_MODEL_NAME,
                label="SVD",
            )
            svd_recommendations = recommender.get_svd_recommendations(
                user_id=self._SVD_SAMPLE_USER_ID,
                n=self._SAMPLE_RECOMMENDATION_COUNT,
                svd_model=svd_model,
            )
            logging.info(f"SVD recommendations: {svd_recommendations}")

            # --- Item-based KNN ---
            logging.info("Training and saving KNN item-based model...")
            recommender.train_knn_item_based()
            item_knn_model = self._save_upload_reload(
                model=recommender.knn_item_based,
                file_path=config.item_knn_trained_model_file_path,
                hub_filename=MODEL_TRAINER_ITEM_KNN_TRAINED_MODEL_NAME,
                label="item-based KNN",
            )
            item_based_recommendations = recommender.get_item_based_recommendations(
                anime_name=self._ITEM_KNN_SAMPLE_ANIME,
                n_recommendations=self._SAMPLE_RECOMMENDATION_COUNT,
                knn_item_model=item_knn_model,
            )
            logging.info(f"Item Based recommendations: {item_based_recommendations}")

            # --- User-based KNN ---
            logging.info("Training and saving KNN user-based model...")
            recommender.train_knn_user_based()
            user_knn_model = self._save_upload_reload(
                model=recommender.knn_user_based,
                file_path=config.user_knn_trained_model_file_path,
                hub_filename=MODEL_TRAINER_USER_KNN_TRAINED_MODEL_NAME,
                label="user-based KNN",
            )
            user_based_recommendations = recommender.get_user_based_recommendations(
                user_id=self._USER_KNN_SAMPLE_USER_ID,
                n_recommendations=self._SAMPLE_RECOMMENDATION_COUNT,
                knn_user_model=user_knn_model,
            )
            logging.info(f"User Based recommendations: {user_based_recommendations}")

            return CollaborativeModelArtifact(
                svd_file_path=config.svd_trained_model_file_path,
                item_based_knn_file_path=config.item_knn_trained_model_file_path,
                user_based_knn_file_path=config.user_knn_trained_model_file_path
            )
        except Exception as e:
            # Chain the original exception so its traceback is not lost.
            raise AnimeRecommendorException(f"Error in CollaborativeModelTrainer: {str(e)}", sys) from e