File size: 5,698 Bytes
b4f6ffc c3c7748 b4f6ffc c3c7748 b4f6ffc c3c7748 b4f6ffc c3c7748 b4f6ffc c3c7748 b4f6ffc c3c7748 b4f6ffc c3c7748 b4f6ffc c3c7748 b4f6ffc c3c7748 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import sys
from anime_recommender.loggers.logging import logging
from anime_recommender.exception.exception import AnimeRecommendorException
from anime_recommender.entity.config_entity import CollaborativeModelConfig
from anime_recommender.entity.artifact_entity import DataTransformationArtifact, CollaborativeModelArtifact
from anime_recommender.utils.main_utils.utils import load_csv_data, save_model, load_object, upload_model_to_huggingface
from anime_recommender.model_trainer.collaborative_modelling import CollaborativeAnimeRecommender
from anime_recommender.constant import *
class CollaborativeModelTrainer:
    """
    Trains and saves collaborative filtering recommendation models.

    This class supports three types of models:
    - Singular Value Decomposition (SVD)
    - Item-based K-Nearest Neighbors (KNN)
    - User-based K-Nearest Neighbors (KNN)

    For each model the trainer fits it through ``CollaborativeAnimeRecommender``,
    writes it with ``save_model``, publishes the saved file with
    ``upload_model_to_huggingface``, reads it back with ``load_object`` and logs
    one sample set of recommendations produced by the reloaded model.
    """

    # Fixed inputs for the post-training sample queries. They only feed the
    # log output. NOTE(review): presumably these ids/title exist in the merged
    # dataset -- confirm, since a failing sample query aborts the whole run.
    _SAMPLE_SVD_USER_ID = 436
    _SAMPLE_ANIME_NAME = 'One Piece'
    _SAMPLE_KNN_USER_ID = 817
    _SAMPLE_RECOMMENDATION_COUNT = 10

    def __init__(self, collaborative_model_trainer_config: CollaborativeModelConfig, data_transformation_artifact: DataTransformationArtifact):
        """
        Initializes the CollaborativeModelTrainer with configuration and transformed data.

        Args:
            collaborative_model_trainer_config (CollaborativeModelConfig): Configuration settings for model training;
                the three ``*_trained_model_file_path`` attributes are read during training.
            data_transformation_artifact (DataTransformationArtifact): Data artifact whose ``merged_file_path``
                is loaded as the training dataset.
        """
        # Plain attribute assignments cannot fail, so no exception wrapper is needed here.
        self.collaborative_model_trainer_config = collaborative_model_trainer_config
        self.data_transformation_artifact = data_transformation_artifact

    def _persist_and_reload(self, model, file_path, hub_filename, label):
        """
        Saves a trained model, uploads the saved file, and loads it back.

        Args:
            model: Trained model object to persist.
            file_path: Local path the model is written to and read back from.
            hub_filename: File name passed to ``upload_model_to_huggingface``.
            label (str): Model description inserted into the "Loading pre-trained ..." log line.

        Returns:
            The object returned by ``load_object`` for ``file_path``.
        """
        save_model(model=model, file_path=file_path)
        upload_model_to_huggingface(
            model_path=file_path,
            repo_id=MODELS_FILEPATH,
            filename=hub_filename
        )
        # Reload from disk so the sample query runs against the persisted file,
        # not the in-memory model.
        logging.info(f"Loading pre-trained {label} model...")
        return load_object(file_path)

    def initiate_model_trainer(self) -> CollaborativeModelArtifact:
        """
        Trains and saves all collaborative filtering models.

        Returns:
            CollaborativeModelArtifact: Object containing file paths of all trained models.

        Raises:
            AnimeRecommendorException: If any step (loading data, training, saving,
                uploading, reloading or the sample query) raises; the original
                exception is chained as ``__cause__``.
        """
        try:
            config = self.collaborative_model_trainer_config
            sample_size = self._SAMPLE_RECOMMENDATION_COUNT

            logging.info("Loading transformed data...")
            df = load_csv_data(self.data_transformation_artifact.merged_file_path)
            recommender = CollaborativeAnimeRecommender(df)

            # Train and save SVD model
            logging.info("Training and saving SVD model...")
            recommender.train_svd()
            svd_model = self._persist_and_reload(
                model=recommender.svd,
                file_path=config.svd_trained_model_file_path,
                hub_filename=MODEL_TRAINER_SVD_TRAINED_MODEL_NAME,
                label="SVD",
            )
            svd_recommendations = recommender.get_svd_recommendations(
                user_id=self._SAMPLE_SVD_USER_ID, n=sample_size, svd_model=svd_model
            )
            logging.info(f"SVD recommendations: {svd_recommendations}")

            # Train and save Item-Based KNN model
            logging.info("Training and saving KNN item-based model...")
            recommender.train_knn_item_based()
            item_knn_model = self._persist_and_reload(
                model=recommender.knn_item_based,
                file_path=config.item_knn_trained_model_file_path,
                hub_filename=MODEL_TRAINER_ITEM_KNN_TRAINED_MODEL_NAME,
                label="item-based KNN",
            )
            item_based_recommendations = recommender.get_item_based_recommendations(
                anime_name=self._SAMPLE_ANIME_NAME, n_recommendations=sample_size, knn_item_model=item_knn_model
            )
            logging.info(f"Item Based recommendations: {item_based_recommendations}")

            # Train and save User-Based KNN model
            logging.info("Training and saving KNN user-based model...")
            recommender.train_knn_user_based()
            user_knn_model = self._persist_and_reload(
                model=recommender.knn_user_based,
                file_path=config.user_knn_trained_model_file_path,
                hub_filename=MODEL_TRAINER_USER_KNN_TRAINED_MODEL_NAME,
                label="user-based KNN",
            )
            user_based_recommendations = recommender.get_user_based_recommendations(
                user_id=self._SAMPLE_KNN_USER_ID, n_recommendations=sample_size, knn_user_model=user_knn_model
            )
            logging.info(f"User Based recommendations: {user_based_recommendations}")

            return CollaborativeModelArtifact(
                svd_file_path=config.svd_trained_model_file_path,
                item_based_knn_file_path=config.item_knn_trained_model_file_path,
                user_based_knn_file_path=config.user_knn_trained_model_file_path
            )
        except Exception as e:
            # Chain explicitly so the root-cause traceback survives the wrapping.
            raise AnimeRecommendorException(f"Error in CollaborativeModelTrainer: {str(e)}", sys) from e