from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import random
import os

from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

from dotenv import load_dotenv
load_dotenv()

# Human-readable model name: passed as the endpoint description and echoed
# back as "model_description" in every evaluation result.
DESCRIPTION = "Random Baseline"

# URL path of the evaluation endpoint; also reported as "api_route".
ROUTE = "/audio"

# Router on which this module registers its endpoint.
router = APIRouter()



@router.post(ROUTE, tags=["Audio Task"],
             description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """
    Evaluate audio classification for rainforest sound detection.

    Current Model: Random Baseline
    - Makes random predictions from the label space (0-1)
    - Used as a baseline for comparison

    Args:
        request: Evaluation settings. The fields read here are
            ``dataset_name``, ``test_size`` and ``test_seed``.

    Returns:
        dict: Submission metadata (username, space URL, timestamp, model
        description, API route), the accuracy on the held-out split, the
        tracked energy/emissions figures, and the dataset configuration
        that was used.
    """
    # Get space info
    username, space_url = get_space_info()

    # Label space of the task (class name -> integer id)
    LABEL_MAPPING = {
        "chainsaw": 0,
        "environment": 1
    }
    label_ids = sorted(LABEL_MAPPING.values())

    # Load and prepare the dataset
    # Because the dataset is gated, we need to use the HF_TOKEN environment variable to authenticate
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))

    # Split dataset; only the held-out part is evaluated
    train_test = dataset["train"].train_test_split(
        test_size=request.test_size,
        seed=request.test_seed,
    )
    test_dataset = train_test["test"]

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    #--------------------------------------------------------------------------------------------
    # YOUR MODEL INFERENCE CODE HERE
    # Update the code below to replace the random baseline by your model inference within the
    # inference pass where the energy consumption and emissions are tracked.
    #
    # TODO: plug in a real model. Everything it needs must be in scope:
    #   - import the model's libraries at module level,
    #   - load the model/processor once, outside this request handler,
    #   - iterate over `test_dataset` in batches and collect one predicted
    #     class id per example into `predictions` (same order as
    #     `true_labels`).
    # Any name used here that is not defined raises NameError on every
    # request, after the emissions task has already been started.
    #--------------------------------------------------------------------------------------------

    # Make random predictions (placeholder for actual model inference)
    true_labels = test_dataset["label"]
    predictions = [random.choice(label_ids) for _ in true_labels]

    #--------------------------------------------------------------------------------------------
    # YOUR MODEL INFERENCE STOPS HERE
    #--------------------------------------------------------------------------------------------

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        # Tracker values are scaled by 1000 to match the unit named in each key
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    return results