from src.model import GraphformerModel
from pathlib import Path
from loguru import logger


class ModelTrainer:
    """Thin wrapper that builds a ``GraphformerModel`` and runs it.

    The constructor makes sure the output directory exists and constructs
    the model with the supplied settings; ``train_and_evaluate`` delegates
    the actual work to ``GraphformerModel.run_model``.
    """

    def __init__(self, df, output_path, epochs: int = 100, test_size: float = 0.3):
        """Create the output directory and construct the underlying model.

        Args:
            df: Input data, forwarded unchanged to ``GraphformerModel``
                (presumably a pandas DataFrame — TODO confirm with caller).
            output_path: Directory for model outputs; any value accepted by
                ``pathlib.Path``. Created (with missing parents) if absent.
            epochs: Number of training epochs forwarded to the model.
            test_size: Test-split size forwarded to the model; presumably a
                fraction in (0, 1). Any validation is left to
                ``GraphformerModel``.
        """
        self.df = df
        self.output_path = output_path
        self.epochs = epochs
        self.test_size = test_size

        # Ensure the output directory exists before the model tries to write
        # into it; exist_ok makes re-runs against the same path safe.
        Path(self.output_path).mkdir(parents=True, exist_ok=True)

        # Initialize the GraphformerModel with the trainer's configuration.
        self.model = GraphformerModel(
            df=self.df,
            output_path=self.output_path,
            epochs=self.epochs,
            test_size=self.test_size,
        )

        logger.info(f"Initialized ModelTrainer with output_path: {self.output_path} and epochs: {self.epochs}")

    def train_and_evaluate(self):
        """Run training and evaluation via ``GraphformerModel.run_model``.

        Raises:
            Exception: Any exception raised by ``run_model`` is logged at
                ERROR level (with its traceback) and then re-raised unchanged.
        """
        try:
            logger.info("Starting model training and evaluation")
            self.model.run_model()
            logger.info("GraphformerModel training and evaluation completed successfully")
        except Exception as e:
            # logger.exception logs at ERROR level and also records the
            # traceback, which logger.error with str(e) alone would lose.
            logger.exception(f"Error during GraphformerModel training and evaluation: {e}")
            raise