import os
import subprocess
import streamlit as st
from huggingface_hub import snapshot_download, login, HfApi

if "quantized_model_path" not in st.session_state:
    st.session_state.quantized_model_path = None
if "upload_to_hf" not in st.session_state:
    st.session_state.upload_to_hf = False

def check_directory_path(directory_name: str) -> str | None:
    """Return the absolute path of a directory if it exists, else None."""
    if os.path.exists(directory_name):
        return os.path.abspath(directory_name)
    return None

# Define quantization types
QUANT_TYPES = [
    "Q2_K", "Q3_K_M", "Q3_K_S", "Q4_K_M", "Q4_K_S",
    "Q5_K_M", "Q5_K_S", "Q6_K"
]
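
# Naming convention: the digit is the approximate bit-width per weight, and the
# _S/_M suffixes select llama.cpp's small/medium K-quant mixes.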

# Fail fast if the llama.cpp checkout is missing from the container image.
model_dir_path = check_directory_path("/app/llama.cpp")
if model_dir_path is None:
    st.error("❌ llama.cpp not found at /app/llama.cpp; conversion and quantization will fail.")

def download_model(hf_model_name, output_dir="/tmp/models"):
    """
    Downloads a Hugging Face model and saves it locally.
    """
    st.write(f"📥 Downloading `{hf_model_name}` from Hugging Face...")
    os.makedirs(output_dir, exist_ok=True)
    # local_dir_use_symlinks is deprecated in recent huggingface_hub releases;
    # passing local_dir alone now downloads real files into it.
    snapshot_download(repo_id=hf_model_name, local_dir=output_dir)
    st.success("✅ Model downloaded successfully!")
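
# Equivalent manual download, useful when debugging outside the app (the
# huggingface-cli tool ships with huggingface_hub):
#   huggingface-cli download Qwen/Qwen2.5-1.5B --local-dir /tmp/models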

def convert_to_gguf(model_dir, output_file):
    """
    Converts a Hugging Face model to GGUF format.
    Returns True on success so the pipeline can stop early on failure.
    """
    st.write(f"🔄 Converting `{model_dir}` to GGUF format...")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    cmd = [
        "python3", "/app/llama.cpp/convert_hf_to_gguf.py", model_dir,
        "--outtype", "f16", "--outfile", output_file
    ]
    process = subprocess.run(cmd, text=True, capture_output=True)
    if process.returncode == 0:
        st.success(f"✅ Conversion complete: `{output_file}`")
        return True
    st.error(f"❌ Conversion failed: {process.stderr}")
    return False
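
# Equivalent manual conversion step, for reference:
#   python3 /app/llama.cpp/convert_hf_to_gguf.py /tmp/models \
#       --outtype f16 --outfile /tmp/models/model-f16.gguf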

def quantize_llama(model_path, quantized_output_path, quant_type):
    """
    Quantizes a GGUF model with llama.cpp's llama-quantize binary.
    Returns True on success so the pipeline can stop early on failure.
    """
    st.write(f"⚡ Quantizing `{model_path}` with `{quant_type}` precision...")
    os.makedirs(os.path.dirname(quantized_output_path), exist_ok=True)
    quantize_path = "/app/llama.cpp/build/bin/llama-quantize"

    cmd = [quantize_path, model_path, quantized_output_path, quant_type]

    process = subprocess.run(cmd, text=True, capture_output=True)

    if process.returncode == 0:
        st.success(f"✅ Quantized model saved at `{quantized_output_path}`")
        return True
    st.error(f"❌ Quantization failed: {process.stderr}")
    return False
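
# Equivalent manual quantization step, for reference:
#   /app/llama.cpp/build/bin/llama-quantize model-f16.gguf model-Q4_K_M.gguf Q4_K_M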

def automate_llama_quantization(hf_model_name, quant_type):
    """
    Orchestrates the entire quantization process.
    """
    output_dir = "/tmp/models"
    gguf_file = os.path.join(output_dir, f"{hf_model_name.replace('/', '_')}.gguf")
    quantized_file = gguf_file.replace(".gguf", f"-{quant_type}.gguf")
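
    # Disk note: the original checkpoint, the f16 GGUF, and the quantized GGUF
    # all accumulate under /tmp/models, so budget a few times the model size.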

    progress_bar = st.progress(0)

    # Step 1: Download
    st.write("### Step 1: Downloading Model")
    download_model(hf_model_name, output_dir)
    progress_bar.progress(33)

    # Step 2: Convert to GGUF (stop early if the converter script fails)
    st.write("### Step 2: Converting Model to GGUF Format")
    if not convert_to_gguf(output_dir, gguf_file):
        return None
    progress_bar.progress(66)

    # Step 3: Quantize Model (quant_type is passed exactly as selected)
    st.write("### Step 3: Quantizing Model")
    if not quantize_llama(gguf_file, quantized_file, quant_type):
        return None
    progress_bar.progress(100)

    st.success(f"🎉 All steps completed! Quantized model available at: `{quantized_file}`")
    return quantized_file
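
# Optional local smoke test of the produced file (assumes the separate
# llama-cpp-python package is installed; this app itself never imports it):
#   from llama_cpp import Llama
#   llm = Llama(model_path="/tmp/models/<model>-Q4_K_M.gguf")
#   print(llm("Hello", max_tokens=8))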

def upload_to_huggingface(file_path, repo_id, token):
    """
    Uploads a file to the Hugging Face Hub.
    """
    try:
        # Log in to Hugging Face
        login(token=token)

        # Initialize HfApi
        api = HfApi()

        # Create the repository if it doesn't exist
        api.create_repo(repo_id, exist_ok=True, repo_type="model")

        # Upload the file
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=os.path.basename(file_path),
            repo_id=repo_id,
        )
        st.success(f"✅ File uploaded to Hugging Face: {repo_id}")
    except Exception as e:
        st.error(f"❌ Failed to upload file: {e}")

st.title("🦙 LLaMA Model Quantization (llama.cpp)")

hf_model_name = st.text_input("Enter Hugging Face Model Name", "Qwen/Qwen2.5-1.5B")
quant_type = st.selectbox("Select Quantization Type", QUANT_TYPES)
start_button = st.button("🚀 Start Quantization")

if start_button:
    with st.spinner("Processing..."):
        st.session_state.quantized_model_path = automate_llama_quantization(hf_model_name, quant_type)

if st.session_state.quantized_model_path:
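    # Note: st.download_button buffers the file handle's contents in memory,
    # which can be heavy for multi-GB quantized models.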
    with open(st.session_state.quantized_model_path, "rb") as f:
        st.download_button("⬇️ Download Quantized Model", f, file_name=os.path.basename(st.session_state.quantized_model_path))
    
    # Checkbox for upload section
    st.session_state.upload_to_hf = st.checkbox("Upload to Hugging Face", value=st.session_state.upload_to_hf)
    
    if st.session_state.upload_to_hf:
        st.write("### Upload to Hugging Face")
        repo_id = st.text_input("Enter Hugging Face Repository ID (e.g., 'username/repo-name')")
        hf_token = st.text_input("Enter Hugging Face Token", type="password")
        
        if st.button("📤 Upload to Hugging Face"):
            if repo_id and hf_token:
                with st.spinner("Uploading..."):
                    upload_to_huggingface(st.session_state.quantized_model_path, repo_id, hf_token)
            else:
                st.warning("Please provide a valid repository ID and Hugging Face token.")