import os
import torch
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
import gradio as gr
import json
from huggingface_hub import HfApi

max_seq_length = 4096
dtype = None
load_in_4bit = True
hf_token = os.getenv("HF_TOKEN")
# NUM is the stage counter: this run fine-tunes CyberSentinel-<NUM> and
# publishes CyberSentinel-<NUM + 1> at the end of the script.
current_num = os.getenv("NUM")

# Fail fast on a bad counter. Otherwise a missing NUM surfaces as a confusing
# "CyberSentinel-None" model lookup failure, and a non-integer NUM only fails
# at the int() conversion after the entire training run has finished.
try:
    int(current_num)
except (TypeError, ValueError) as err:
    raise RuntimeError(
        f"NUM must be set to an integer stage number, got {current_num!r}"
    ) from err

print(f"stage {current_num}")

api = HfApi(token=hf_token)
models = f"dad1909/CyberSentinel-{current_num}"

print("Starting model and tokenizer loading...")

# Hub checkpoint for the current stage, loaded with the settings defined above.
load_settings = dict(
    model_name=models,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    token=hf_token,
)
model, tokenizer = FastLanguageModel.from_pretrained(**load_settings)
print("Model and tokenizer loaded successfully.")

print("Configuring PEFT model...")
# LoRA adapter configuration: rank 16 on every attention projection
# (q/k/v/o) and every MLP projection (gate/up/down).
lora_settings = dict(
    r=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
model = FastLanguageModel.get_peft_model(model, **lora_settings)
print("PEFT model configured.")

# Prompt templates keyed by the category returned from detect_prompt_type().
# Each template has two "{}" slots: the instruction text, then the expected
# output.
alpaca_prompt = dict(
    learning_from=(
        "Below is a CVE definition.\n"
        "\n"
        "### CVE definition:\n"
        "{}\n"
        "\n"
        "### detail CVE:\n"
        "{}"
    ),
    definition=(
        "Below is a definition about software vulnerability. Explain it.\n"
        "\n"
        "### Definition:\n"
        "{}\n"
        "\n"
        "### Explanation:\n"
        "{}"
    ),
    code_vulnerability=(
        "Below is a code snippet. Identify the line of code that is vulnerable"
        " and describe the type of software vulnerability.\n"
        "\n"
        "### Code Snippet:\n"
        "{}\n"
        "\n"
        "### Vulnerability solution:\n"
        "{}"
    ),
)

# Tokenizer's end-of-sequence marker; appended to every formatted example.
EOS_TOKEN = tokenizer.eos_token

def detect_prompt_type(instruction):
    """Classify an instruction string by its (case-sensitive) leading prefix.

    Returns one of ``"code_vulnerability"``, ``"learning_from"``,
    ``"definition"``, or ``"unknown"`` when no prefix matches.
    """
    # Checked in order: the code-vulnerability prefix itself begins with
    # "what is", so it must be tested before the generic definition prefix.
    prefix_rules = (
        ("what is code vulnerable of this code:", "code_vulnerability"),
        ("Learning from", "learning_from"),
        ("what is", "definition"),
    )
    return next(
        (label for prefix, label in prefix_rules if instruction.startswith(prefix)),
        "unknown",
    )

def formatting_prompts_func(examples):
    """Render a batch of instruction/output pairs into training text.

    Used with ``Dataset.map(..., batched=True)``, so ``examples`` maps column
    names to lists; pairs are taken from the ``"instruction"`` and ``"output"``
    columns (zip stops at the shorter one). Each pair is wrapped in the
    ``alpaca_prompt`` template selected by ``detect_prompt_type``; pairs with
    no template are joined with a blank line instead. Every rendered string
    ends with ``EOS_TOKEN``.

    Returns:
        ``{"text": [...]}`` with one rendered string per pair.
    """
    def render(instruction, output):
        template = alpaca_prompt.get(detect_prompt_type(instruction))
        if template is None:
            body = instruction + "\n\n" + output
        else:
            body = template.format(instruction, output)
        return body + EOS_TOKEN

    pairs = zip(examples["instruction"], examples["output"])
    return {"text": [render(instruction, output) for instruction, output in pairs]}

print("Loading dataset...")
dataset = load_dataset("dad1909/DCSV", split="train")
print("Dataset loaded successfully.")

print("Applying formatting function to the dataset...")
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Formatting function applied.")

print("Initializing trainer...")
# Probe bf16 support once; fp16 is used as the fallback when it is missing.
use_bf16 = is_bfloat16_supported()
training_args = TrainingArguments(
    output_dir="outputs",
    seed=3407,
    # 5 examples x 5 accumulation steps = 25 examples per optimizer step
    # per device.
    per_device_train_batch_size=5,
    gradient_accumulation_steps=5,
    max_steps=200,
    warmup_steps=5,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    weight_decay=0.01,
    optim="adamw_8bit",
    bf16=use_bf16,
    fp16=not use_bf16,
    logging_steps=10,
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
)
print("Trainer initialized.")

print("Starting training...")
trainer_stats = trainer.train()
print("Training completed.")

# Next stage number: names the uploaded checkpoint and is written back to the
# Space's NUM variable so the following run picks it up.
num = int(current_num) + 1

# NOTE(review): no "dad1909/" namespace here, unlike the repo loaded at the
# start of the script; presumably resolved to the token owner's account —
# confirm it matches the namespace the next stage loads from.
uploads_models = f"CyberSentinel-{num}"

print("Saving the trained model...")
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
print("Model saved successfully.")

print("Pushing the model to the hub...")
model.push_to_hub_merged(
    uploads_models,
    tokenizer,
    save_method="merged_16bit",
    token=hf_token,
)
print("Model pushed to hub successfully.")

# add_space_variable adds or updates the variable, so no prior delete is
# needed. Deleting first only opens a window in which the Space has no NUM,
# and loses the counter entirely if the subsequent add fails.
api.add_space_variable(repo_id="dad1909/CyberCode", key="NUM", value=str(num))