import os
import torch
from dataclasses import dataclass
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
from kto_dataset_processor import process_feel_dataset, SupportedLanguages
from datetime import datetime
import wandb
from enum import Enum
from typing import Optional
from pathlib import Path


# PEFT library: attach and load adapters
from peft import get_peft_model, PeftModel

####################################
#  CONFIGURATION
####################################


@dataclass
class ScriptArguments:
    """
    Configuration for the script.
    """
    process_dataset_func: callable = process_feel_dataset
    checkpoint_path: str = None
    push_to_hub: bool = True
    language: str = "English"  # Default to English

    def __post_init__(self):
        """Validate the language after initialization"""
        try:
            # This will raise ValueError if language is not in the enum
            SupportedLanguages(self.language)
        except ValueError:
            supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages])
            raise ValueError(
                f"Invalid language: '{self.language}'\n"
                f"Supported languages are:\n- {supported_langs}"
            )

@dataclass
class ModelArguments(ModelConfig):
    """
    Configuration for the model.
    """
    model_name: str = "CohereForAI/aya-expanse-8b"
    use_peft: bool = True
    lora_target_modules: str = "all-linear"
    lora_r: int = 16
    lora_alpha: int = 16
    trust_remote_code: bool = True

@dataclass
class TrainingArguments(KTOConfig):
    """
    Configuration for the KTO trainer.
    """
    output_dir: str = f"kto_{ModelArguments.model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 4
    learning_rate: float = 5e-7
    lr_scheduler_type: str = "cosine"
    gradient_accumulation_steps: int = 1
    logging_steps: int = 10
    eval_steps: int = 500
    warmup_ratio: float = 0.1
    bf16: bool = True
    logging_first_step: bool = True

# Initialize configurations
script_args = ScriptArguments()
training_args = TrainingArguments()
model_args = ModelArguments()

####################################
#  HELPER FUNCTIONS
####################################

def load_model_and_tokenizer(model_args):
    """
    Load the base model and tokenizer from the Hugging Face Hub.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code
    )

    # Set pad token if it is missing
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Setup chat format if not available on the tokenizer
    if not getattr(tokenizer, "chat_template", None):
        model, tokenizer = setup_chat_format(model, tokenizer)

    return model, tokenizer

def get_adapter_path(model_name: str, language: str, timestamp: str = None) -> Path:
    """
    Generate standardized adapter path.
    If timestamp is None, returns the base language directory.
    Otherwise, returns specific adapter version path.

    Format: adapters/{model_name}/{language}/version_{timestamp}
    """
    # Clean model name (remove slashes, etc.)
    clean_model_name = model_name.replace('/', '_')

    base_path = Path("adapters") / clean_model_name / language
    if timestamp:
        return base_path / f"version_{timestamp}"
    return base_path

def load_latest_adapter(model, model_name: str, language: str) -> tuple[PeftModel, str]:
    """
    Load the most recent adapter for given model and language.
    Returns: (loaded_model, timestamp of loaded adapter)
    """
    adapter_base = get_adapter_path(model_name, language)

    if not adapter_base.exists():
        return None, None

    # Get all version directories and sort by timestamp
    versions = sorted(
        [d for d in adapter_base.glob("version_*")],
        key=lambda x: x.name,
        reverse=True
    )

    if not versions:
        return None, None

    latest_version = versions[0]
    timestamp = latest_version.name.replace("version_", "")

    model = PeftModel.from_pretrained(model, latest_version, is_trainable=True)
    return model, timestamp

####################################
#  MAIN LOGIC
####################################

def main():
    # Initialize wandb for logging
    wandb.init(project="kto")

    # Get timestamp at start of training
    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    print("Loading base model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")

    # Load existing adapter or create new one
    loaded_model, previous_timestamp = load_latest_adapter(
        model,
        model_args.model_name,
        script_args.language
    )

    if loaded_model is not None:
        model = loaded_model
        print(f"Loaded existing adapter trained at {previous_timestamp}")
    else:
        # Initialize new LoRA adapter
        peft_config = get_peft_config(model_args)
        model = get_peft_model(model, peft_config)
        print("Initialized new adapter")

    # -----------------------------
    # Data Preparation and Training
    # -----------------------------
    print("Processing dataset...")
    dataset = script_args.process_dataset_func(script_args.language)
    print("Dataset processed.")

    print("Initializing trainer...")
    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Log metrics to wandb
    wandb.log({
        "epoch": metrics.get("epoch"),
        "grad_norm": metrics.get("grad_norm"),
        "kl": metrics.get("kl"),
        "learning_rate": metrics.get("learning_rate"),
        "logits/chosen": metrics.get("logits/chosen"),
        "logits/rejected": metrics.get("logits/rejected"),
        "logps/chosen": metrics.get("logps/chosen"),
        "logps/rejected": metrics.get("logps/rejected"),
        "loss": metrics.get("loss"),
        "rewards/chosen": metrics.get("rewards/chosen"),
        "rewards/margins": metrics.get("rewards/margins"),
        "rewards/rejected": metrics.get("rewards/rejected"),
        "step": metrics.get("step")
    })

    # Save the adapter
    adapter_path = get_adapter_path(
        model_args.model_name,
        script_args.language,
        training_timestamp
    )
    adapter_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Saving adapter to: {adapter_path}")
    model.save_pretrained(adapter_path)

    # Save metadata
    metadata = AdapterMetadata(
        training_timestamp=training_timestamp,
        model_name=model_args.model_name,
        language=script_args.language,
    )
    metadata.save(adapter_path / "metadata.json")

    if script_args.push_to_hub:
        repo_id = f"feel-fl/adapters/{model_args.model_name.replace('/', '_')}/{script_args.language}"
        print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
        model.push_to_hub(repo_id=repo_id)

    print("Process completed.")

    # Finish wandb run
    wandb.finish()

if __name__ == "__main__":
    main()