import os
import platform
import sys
import time
import boto3
from botocore.exceptions import NoCredentialsError
import logging

import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# device = "cuda"

# has_gpu = torch.cuda.is_available()
# device = "cuda" if has_gpu else "cpu"

# print(f"Python Platform: {platform.platform()}")
# print(f"Python Version: {sys.version}")
# print(f"PyTorch Version: {torch.__version__}")
# print("GPU Availability:", "Available" if has_gpu else "Not Available")
# print(f"Target Device: {device}")

# if has_gpu:
#     print(f"GPU Type: {torch.cuda.get_device_name(0)}")
#     print(f"CUDA Version: {torch.version.cuda}")
# else:
#     print("CUDA is not available.")

def download_xmad_file():
    s3 = boto3.client('s3',
                      aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
                      aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'))
    
    # Create the .codebooks directory if it doesn't exist
    codebooks_dir = '.codebooks'
    os.makedirs(codebooks_dir, exist_ok=True)
    
    temp_file_path = os.path.join(codebooks_dir, 'llama-3-8b-instruct_1bit.xmad')
    
    try:
        # Download the file to the .codebooks directory
        s3.download_file('xmad-quantized-models', 'llama-3-8b-instruct_1bit.xmad', temp_file_path)
        print("Download Successful")

        # Restrict permissions on the .codebooks directory
        os.chmod(codebooks_dir, 0o700)

    except NoCredentialsError:
        print("Credentials not available")

download_xmad_file()


def get_gpu_memory():
    return torch.cuda.memory_allocated() / 1024 / 1024  # Convert to MiB


class TorchTracemalloc:
    def __init__(self):
        self.begin = 0
        self.peak = 0

    def __enter__(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        self.begin = get_gpu_memory()
        return self

    def __exit__(self, *exc):
        torch.cuda.synchronize()
        self.peak = torch.cuda.max_memory_allocated() / 1024 / 1024

    def consumed(self):
        return self.peak - self.begin


def load_model_and_tokenizer():
    model_name = "NousResearch/Meta-Llama-3-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    special_tokens = {"pad_token": "<PAD>"}
    tokenizer.add_special_tokens(special_tokens)
    config = AutoConfig.from_pretrained(model_name)
    config.quantizer_path = ".codebooks/llama-3-8b-instruct_1bit.xmad"
    config.window_length = 32
    model = AutoModelForCausalLM.from_pretrained(
        model_name, config=config, torch_dtype=torch.float16, device_map="auto"
    )
    
    print(f"Quantizer path in model config: {model.config.quantizer_path}")
    logging.info(f"Quantizer path in model config: {model.config.quantizer_path}")

    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))
    tokenizer.padding_side = "left"
    model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer


model, tokenizer = load_model_and_tokenizer()


def process_dialog(message, history):
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    # Rebuild the conversation from Gradio's (user, assistant) history pairs.
    dialog = []
    for user_msg, assistant_msg in history:
        dialog.append({"role": "user", "content": user_msg})
        if assistant_msg:
            dialog.append({"role": "assistant", "content": assistant_msg})
    dialog.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        dialog, tokenize=False, add_generation_prompt=True
    )
    tokenized_input_prompt_ids = tokenizer(
        prompt, return_tensors="pt"
    ).input_ids.to(model.device)

    start_time = time.time()

    with TorchTracemalloc() as tracemalloc:
        with torch.no_grad():
            output = model.generate(
                tokenized_input_prompt_ids,
                max_new_tokens=512,  # cap the reply length rather than relying on the model's default max_length
                temperature=0.4,
                do_sample=True,
                eos_token_id=terminators,
                pad_token_id=tokenizer.pad_token_id,
            )

    end_time = time.time()

    response = output[0][tokenized_input_prompt_ids.shape[-1] :]
    cleaned_response = tokenizer.decode(
        response,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    generation_time = end_time - start_time
    gpu_memory = tracemalloc.consumed()

    return cleaned_response, generation_time, gpu_memory

def chatbot_response(message, history):
    response, generation_time, gpu_memory = process_dialog(message, history)

    metrics = f"\n\n---\n\n **Metrics**\t*Answer Generation Time:* `{generation_time:.2f} sec`\t*GPU Memory Consumption:* `{gpu_memory:.2f} MiB`\n\n"
    return response + metrics


demo = gr.ChatInterface(
    fn=chatbot_response,
    examples=["Hello", "How are you?", "Tell me a joke"],
    title="Chat with xMAD's: 1-bit-Llama-3-8B-Instruct Model",
    description="Contact support@xmad.ai to set up a demo",
)

if __name__ == "__main__":
    username = os.getenv("AUTH_USERNAME")
    password = os.getenv("AUTH_PASSWORD")
    # Only enable auth when both credentials are set; otherwise launch without it.
    if username and password:
        demo.launch(auth=(username, password))
    else:
        demo.launch()