Commit
·
6e95f91
1
Parent(s):
293022c
Collaborative filtering completed
Browse files
anime_recommender/entity/artifact_entity.py
CHANGED
@@ -10,9 +10,12 @@ class DataIngestionArtifact:
|
|
10 |
class DataTransformationArtifact:
|
11 |
merged_file_path:str
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
16 |
@dataclass
|
17 |
class ContentBasedModelArtifact:
|
18 |
cosine_similarity_model_file_path:str
|
|
|
10 |
class DataTransformationArtifact:
|
11 |
merged_file_path:str
|
12 |
|
13 |
+
@dataclass
|
14 |
+
class CollaborativeModelArtifact:
|
15 |
+
svd_file_path: Optional[str] = None
|
16 |
+
item_based_knn_file_path: Optional[str] = None
|
17 |
+
user_based_knn_file_path: Optional[str] = None
|
18 |
+
|
19 |
@dataclass
|
20 |
class ContentBasedModelArtifact:
|
21 |
cosine_similarity_model_file_path:str
|
anime_recommender/entity/config_entity.py
CHANGED
@@ -41,6 +41,18 @@ class DataTransformationConfig:
|
|
41 |
self.data_transformation_dir:str = os.path.join(training_pipeline_config.artifact_dir,DATA_TRANSFORMATION_DIR)
|
42 |
self.merged_file_path:str = os.path.join(self.data_transformation_dir,DATA_TRANSFORMATION_TRANSFORMED_DATA_DIR,MERGED_FILE_NAME)
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
class ContentBasedModelConfig:
|
46 |
"""
|
|
|
41 |
self.data_transformation_dir:str = os.path.join(training_pipeline_config.artifact_dir,DATA_TRANSFORMATION_DIR)
|
42 |
self.merged_file_path:str = os.path.join(self.data_transformation_dir,DATA_TRANSFORMATION_TRANSFORMED_DATA_DIR,MERGED_FILE_NAME)
|
43 |
|
44 |
+
class CollaborativeModelConfig:
|
45 |
+
"""
|
46 |
+
Configuration for model training, including paths for trained models.
|
47 |
+
"""
|
48 |
+
def __init__(self,training_pipeline_config:TrainingPipelineConfig):
|
49 |
+
"""
|
50 |
+
Initialize model trainer paths.
|
51 |
+
"""
|
52 |
+
self.model_trainer_dir:str = os.path.join(training_pipeline_config.artifact_dir,MODEL_TRAINER_DIR_NAME)
|
53 |
+
self.svd_trained_model_file_path:str = os.path.join(self.model_trainer_dir,MODEL_TRAINER_COL_TRAINED_MODEL_DIR,MODEL_TRAINER_SVD_TRAINED_MODEL_NAME)
|
54 |
+
self.user_knn_trained_model_file_path:str = os.path.join(self.model_trainer_dir,MODEL_TRAINER_COL_TRAINED_MODEL_DIR,MODEL_TRAINER_USER_KNN_TRAINED_MODEL_NAME)
|
55 |
+
self.item_knn_trained_model_file_path:str = os.path.join(self.model_trainer_dir,MODEL_TRAINER_COL_TRAINED_MODEL_DIR,MODEL_TRAINER_ITEM_KNN_TRAINED_MODEL_NAME)
|
56 |
|
57 |
class ContentBasedModelConfig:
|
58 |
"""
|
anime_recommender/source/collaborative_recommenders.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from anime_recommender.loggers.logging import logging
|
3 |
+
from anime_recommender.exception.exception import AnimeRecommendorException
|
4 |
+
from anime_recommender.entity.config_entity import CollaborativeModelConfig
|
5 |
+
from anime_recommender.entity.artifact_entity import DataTransformationArtifact, CollaborativeModelArtifact
|
6 |
+
from anime_recommender.utils.main_utils.utils import load_csv_data, save_model, load_object
|
7 |
+
from anime_recommender.model_trainer.collaborative_filtering import CollaborativeAnimeRecommender
|
8 |
+
|
9 |
+
class CollaborativeModelTrainer:
|
10 |
+
"""
|
11 |
+
Class to train the model, track metrics, and save the trained model.
|
12 |
+
"""
|
13 |
+
def __init__(self, collaborative_model_trainer_config: CollaborativeModelConfig, data_transformation_artifact: DataTransformationArtifact):
|
14 |
+
try:
|
15 |
+
self.collaborative_model_trainer_config = collaborative_model_trainer_config
|
16 |
+
self.data_transformation_artifact = data_transformation_artifact
|
17 |
+
except Exception as e:
|
18 |
+
raise AnimeRecommendorException(e, sys)
|
19 |
+
|
20 |
+
def initiate_model_trainer(self, model_type: str) -> CollaborativeModelArtifact:
|
21 |
+
try:
|
22 |
+
logging.info("Loading transformed data...")
|
23 |
+
df = load_csv_data(self.data_transformation_artifact.merged_file_path)
|
24 |
+
recommender = CollaborativeAnimeRecommender(df)
|
25 |
+
# recommender.print_unique_user_ids()
|
26 |
+
if model_type == 'svd':
|
27 |
+
logging.info("Training and saving SVD model...")
|
28 |
+
recommender.train_svd()
|
29 |
+
save_model(recommender.svd, self.collaborative_model_trainer_config.svd_trained_model_file_path)
|
30 |
+
|
31 |
+
logging.info("Loading pre-trained SVD model...")
|
32 |
+
svd_model = load_object(self.collaborative_model_trainer_config.svd_trained_model_file_path)
|
33 |
+
svd_recommendations = recommender.get_svd_recommendations(user_id=436, n=10, svd_model=svd_model)
|
34 |
+
logging.info(f"SVD recommendations: {svd_recommendations}")
|
35 |
+
return CollaborativeModelArtifact(
|
36 |
+
svd_file_path=self.collaborative_model_trainer_config.svd_trained_model_file_path
|
37 |
+
)
|
38 |
+
|
39 |
+
elif model_type == 'item_knn':
|
40 |
+
logging.info("Training and saving KNN item-based model...")
|
41 |
+
recommender.train_knn_item_based()
|
42 |
+
save_model(recommender.knn_item_based, self.collaborative_model_trainer_config.item_knn_trained_model_file_path)
|
43 |
+
|
44 |
+
logging.info("Loading pre-trained item-based KNN model...")
|
45 |
+
item_knn_model = load_object(self.collaborative_model_trainer_config.item_knn_trained_model_file_path)
|
46 |
+
item_based_recommendations = recommender.get_item_based_recommendations(
|
47 |
+
anime_name='One Piece', n_recommendations=10, knn_item_model=item_knn_model
|
48 |
+
)
|
49 |
+
logging.info(f"Item Based recommendations: {item_based_recommendations}")
|
50 |
+
return CollaborativeModelArtifact(
|
51 |
+
item_based_knn_file_path=self.collaborative_model_trainer_config.item_knn_trained_model_file_path
|
52 |
+
)
|
53 |
+
|
54 |
+
elif model_type == 'user_knn':
|
55 |
+
logging.info("Training and saving KNN user-based model...")
|
56 |
+
recommender.train_knn_user_based()
|
57 |
+
save_model(recommender.knn_user_based, self.collaborative_model_trainer_config.user_knn_trained_model_file_path)
|
58 |
+
|
59 |
+
logging.info("Loading pre-trained user-based KNN model...")
|
60 |
+
user_knn_model = load_object(self.collaborative_model_trainer_config.user_knn_trained_model_file_path)
|
61 |
+
user_based_recommendations = recommender.get_user_based_recommendations(
|
62 |
+
user_id=817, n_recommendations=10, knn_user_model=user_knn_model
|
63 |
+
)
|
64 |
+
logging.info(f"User Based recommendations: {user_based_recommendations}")
|
65 |
+
return CollaborativeModelArtifact(
|
66 |
+
user_based_knn_file_path=self.collaborative_model_trainer_config.user_knn_trained_model_file_path
|
67 |
+
)
|
68 |
+
|
69 |
+
else:
|
70 |
+
raise ValueError("Invalid model_type. Choose from 'svd', 'item_knn', or 'user_knn'.")
|
71 |
+
|
72 |
+
except Exception as e:
|
73 |
+
raise AnimeRecommendorException(f"Error in CollaborativeModelTrainer: {str(e)}", sys)
|
app.py
CHANGED
@@ -40,10 +40,6 @@ item_based_knn_model_path = hf_hub_download(repo_name, "itembasedknn.pkl")
|
|
40 |
user_based_knn_model_path = hf_hub_download(repo_name, "userbasedknn.pkl")
|
41 |
svd_model_path = hf_hub_download(repo_name, "svd.pkl")
|
42 |
|
43 |
-
# # Load the models into memory
|
44 |
-
# with open(cosine_similarity_model_path, "rb") as f:
|
45 |
-
# cosine_similarity_model = joblib.load(f)
|
46 |
-
|
47 |
with open(item_based_knn_model_path, "rb") as f:
|
48 |
item_based_knn_model = joblib.load(f)
|
49 |
|
|
|
40 |
user_based_knn_model_path = hf_hub_download(repo_name, "userbasedknn.pkl")
|
41 |
svd_model_path = hf_hub_download(repo_name, "svd.pkl")
|
42 |
|
|
|
|
|
|
|
|
|
43 |
with open(item_based_knn_model_path, "rb") as f:
|
44 |
item_based_knn_model = joblib.load(f)
|
45 |
|
run_pipeline.py
CHANGED
@@ -2,10 +2,10 @@ import sys
|
|
2 |
from anime_recommender.loggers.logging import logging
|
3 |
from anime_recommender.exception.exception import AnimeRecommendorException
|
4 |
from anime_recommender.source.data_ingestion import DataIngestion
|
5 |
-
from anime_recommender.entity.config_entity import TrainingPipelineConfig,DataIngestionConfig,DataTransformationConfig,ContentBasedModelConfig
|
6 |
-
# ,DataTransformationConfig
|
7 |
from anime_recommender.source.data_transformation import DataTransformation
|
8 |
-
|
9 |
from anime_recommender.source.content_based_recommender import ContentBasedModelTrainer
|
10 |
# from anime_recommender.source.popularity_based_recommenders import PopularityBasedRecommendor
|
11 |
|
@@ -27,21 +27,21 @@ if __name__ == "__main__":
|
|
27 |
logging.info("Data Transformation Completed.")
|
28 |
print(data_transformation_artifact)
|
29 |
|
30 |
-
#
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
|
38 |
# Content Based Model Training
|
39 |
-
content_based_model_trainer_config = ContentBasedModelConfig(training_pipeline_config)
|
40 |
-
content_based_model_trainer = ContentBasedModelTrainer(content_based_model_trainer_config=content_based_model_trainer_config,data_ingestion_artifact=data_ingestion_artifact)
|
41 |
-
logging.info("Initiating Content Based Model training.")
|
42 |
-
content_based_model_trainer_artifact = content_based_model_trainer.initiate_model_trainer()
|
43 |
-
logging.info("Content Based Model training completed.")
|
44 |
-
print(content_based_model_trainer_artifact)
|
45 |
|
46 |
# # Popularity Based Filtering
|
47 |
# logging.info("Initiating Popularity based filtering.")
|
|
|
2 |
from anime_recommender.loggers.logging import logging
|
3 |
from anime_recommender.exception.exception import AnimeRecommendorException
|
4 |
from anime_recommender.source.data_ingestion import DataIngestion
|
5 |
+
from anime_recommender.entity.config_entity import TrainingPipelineConfig,DataIngestionConfig,DataTransformationConfig,CollaborativeModelConfig,ContentBasedModelConfig
|
6 |
+
# ,DataTransformationConfig
|
7 |
from anime_recommender.source.data_transformation import DataTransformation
|
8 |
+
from anime_recommender.source.collaborative_recommenders import CollaborativeModelTrainer
|
9 |
from anime_recommender.source.content_based_recommender import ContentBasedModelTrainer
|
10 |
# from anime_recommender.source.popularity_based_recommenders import PopularityBasedRecommendor
|
11 |
|
|
|
27 |
logging.info("Data Transformation Completed.")
|
28 |
print(data_transformation_artifact)
|
29 |
|
30 |
+
# Collaborative Model Training
|
31 |
+
collaborative_model_trainer_config = CollaborativeModelConfig(training_pipeline_config)
|
32 |
+
collaborative_model_trainer = CollaborativeModelTrainer(collaborative_model_trainer_config= collaborative_model_trainer_config,data_transformation_artifact=data_transformation_artifact)
|
33 |
+
logging.info("Initiating Collaborative Model training.")
|
34 |
+
collaborative_model_trainer_artifact = collaborative_model_trainer.initiate_model_trainer(model_type='user_knn')
|
35 |
+
logging.info("Collaborative Model training completed.")
|
36 |
+
print(collaborative_model_trainer_artifact)
|
37 |
|
38 |
# Content Based Model Training
|
39 |
+
# content_based_model_trainer_config = ContentBasedModelConfig(training_pipeline_config)
|
40 |
+
# content_based_model_trainer = ContentBasedModelTrainer(content_based_model_trainer_config=content_based_model_trainer_config,data_ingestion_artifact=data_ingestion_artifact)
|
41 |
+
# logging.info("Initiating Content Based Model training.")
|
42 |
+
# content_based_model_trainer_artifact = content_based_model_trainer.initiate_model_trainer()
|
43 |
+
# logging.info("Content Based Model training completed.")
|
44 |
+
# print(content_based_model_trainer_artifact)
|
45 |
|
46 |
# # Popularity Based Filtering
|
47 |
# logging.info("Initiating Popularity based filtering.")
|