import os
from typing import Tuple, List
import gradio as gr
import spaces
from dataclasses import dataclass
from huggingface_hub import HfApi, CommitOperationAdd
from transformers import AutoProcessor
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot, wrap_hf_model_class

@dataclass
class CommitInfo:
    """Result of a successful upload, as returned by ``push_to_hub``.

    Attributes:
        repo_url: Browser URL of the repository that received the commit.
    """

    repo_url: str

def parse_ignore_list(ignore_str: str) -> List[str]:
    """Split a comma-separated ignore string into trimmed, non-empty patterns.

    ``None`` (which API callers may send) or an empty/blank string yields ``[]``,
    meaning no layer is excluded from quantization.

    >>> parse_ignore_list(" re:.*lm_head , ,re:vision_model.* ")
    ['re:.*lm_head', 're:vision_model.*']
    """
    if not ignore_str:
        return []
    return [item.strip() for item in ignore_str.split(",") if item.strip()]

def _resolve_model_class(model_class_name: str) -> type:
    """Return the ``transformers`` class named ``model_class_name`` without exec/eval.

    A blank name falls back to ``AutoModelForCausalLM``. The name comes straight
    from a user-facing text box, so it is validated as a plain public identifier
    and looked up as an attribute instead of being interpolated into code.

    Raises:
        ValueError: if the name is not a public identifier or does not resolve to
            a class with ``from_pretrained`` in ``transformers``.
    """
    import transformers

    name = (model_class_name or "").strip() or "AutoModelForCausalLM"
    if not name.isidentifier() or name.startswith("_"):
        raise ValueError(f"Invalid model class name: {model_class_name!r}")
    model_class = getattr(transformers, name, None)
    if not isinstance(model_class, type) or not hasattr(model_class, "from_pretrained"):
        raise ValueError(f"transformers has no model class named {name!r}")
    return model_class


def create_quantized_model(
    model_id: str,
    work_dir: str,
    ignore_list: List[str],
    model_class_name: str
) -> Tuple[str, List[Tuple[str, Exception]]]:
    """Quantize a model to FP8 (``FP8_DYNAMIC`` scheme) and save it to disk.

    Args:
        model_id: Hub repo id of the model to quantize, e.g. ``"org/name"``.
        work_dir: Directory in which the output folder
            ``<model name>-FP8-dynamic`` is created.
        ignore_list: Layer names / ``re:`` patterns passed to
            ``QuantizationModifier(ignore=...)``.
        model_class_name: Name of a class exported by ``transformers`` used to
            load the model (blank means ``AutoModelForCausalLM``).

    Returns:
        ``(save_dir, errors)``. ``errors`` is kept for interface compatibility
        and is always empty: any failure propagates as an exception instead.

    Raises:
        ValueError: if ``model_class_name`` does not name a transformers class.
        Exception: anything raised while loading, quantizing or saving.
    """
    errors: List[Tuple[str, Exception]] = []

    model_class = _resolve_model_class(model_class_name)
    wrapped_model_class = wrap_hf_model_class(model_class)

    # Load model with ZeroGPU
    model = wrapped_model_class.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
        _attn_implementation="eager"
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    # FP8_DYNAMIC on every Linear layer except the user-supplied ignore patterns.
    recipe = QuantizationModifier(
        targets="Linear",
        scheme="FP8_DYNAMIC",
        ignore=ignore_list,
    )

    # Apply quantization and write model + processor files into one folder.
    save_dir = os.path.join(work_dir, f"{model_id.split('/')[-1]}-FP8-dynamic")
    oneshot(model=model, recipe=recipe, output_dir=save_dir)
    processor.save_pretrained(save_dir)

    return save_dir, errors

def push_to_hub(
    api: HfApi,
    model_id: str,
    quantized_path: str,
    token: str,
    ignore_list: List[str],
    model_class_name: str,
) -> CommitInfo:
    """Create (or reuse) a public repo under the token owner and upload the model.

    The repo is named ``<token owner>/<original model name>-fp8``. A generated
    README.md model card and every file under ``quantized_path`` are uploaded
    in a single commit.

    Args:
        api: Hub client used for ``whoami``, ``create_repo`` and ``create_commit``.
        model_id: Source model repo id, e.g. ``"org/name"``. It is used in the
            repo name, the card metadata and the commit message.
        quantized_path: Local folder produced by ``create_quantized_model``.
        token: Hub token, passed to ``api.whoami`` to resolve the target namespace.
        ignore_list: Ignored layer patterns, recorded in the model card.
        model_class_name: transformers class name, recorded in the model card.

    Returns:
        CommitInfo with the URL of the target repository.
    """
    new_model_name = f"{model_id.split('/')[-1]}-fp8"

    # Upload under the account that owns the token, not under the original owner.
    token_owner = api.whoami(token)["name"]
    target_repo = f"{token_owner}/{new_model_name}"

    # NOTE(review): the license is hard-coded to apache-2.0 regardless of the
    # base model's license; consider copying the base model's license instead.
    model_card = f"""---
language:
- en
license: apache-2.0
tags:
- fp8
- quantized
- llmcompressor
base_model: {model_id}
quantization_config:
  ignored_layers: {ignore_list}
  model_class: {model_class_name}
---

# {new_model_name}

This is an FP8-quantized version of [{model_id}](https://huggingface.co/{model_id}) using [LLM Compressor](https://github.com/vllm-project/llm-compressor).

## Quantization Details

- Weights quantized to FP8 with per channel PTQ
- Activations quantized to FP8 with dynamic per token
- Linear layers targeted for quantization
- Ignored layers: {ignore_list}
- Model class: {model_class_name}

## Usage

```python
from transformers import {model_class_name}, AutoProcessor

model = {model_class_name}.from_pretrained("{target_repo}")
processor = AutoProcessor.from_pretrained("{target_repo}")
```
"""

    api.create_repo(
        repo_id=target_repo,
        private=False,
        exist_ok=True,
    )

    # CommitOperationAdd takes ``path_or_fileobj``. A ``str`` is interpreted as a
    # local file path, so in-memory content must be passed as bytes.
    operations = [
        CommitOperationAdd(
            path_in_repo="README.md",
            path_or_fileobj=model_card.encode("utf-8"),
        ),
    ]

    for root, _dirs, files in os.walk(quantized_path):
        for filename in files:
            file_path = os.path.join(root, filename)
            # Hub paths always use "/" regardless of the local OS separator.
            path_in_repo = os.path.relpath(file_path, quantized_path).replace(os.sep, "/")
            if path_in_repo == "README.md":
                # The generated card above wins; two operations on the same path
                # in one commit would conflict.
                continue
            operations.append(
                CommitOperationAdd(
                    path_in_repo=path_in_repo,
                    path_or_fileobj=file_path,
                )
            )

    api.create_commit(
        repo_id=target_repo,
        operations=operations,
        commit_message=f"Add FP8 quantized version of {model_id}",
    )

    return CommitInfo(repo_url=f"https://huggingface.co/{target_repo}")

@spaces.GPU(duration=900)  # 15 minutes timeout for large models
def run(
    model_id: str,
    token: str,
    ignore_str: str,
    model_class_name: str
) -> str:
    """Quantize ``model_id`` to FP8, upload it, and return a Markdown status report.

    Called by the Gradio "Submit" button and the exposed API. Failures during
    quantization or upload are caught, logged with a traceback, and reported in
    the returned Markdown rather than raised.

    Args:
        model_id: Hub repo id of the model to quantize.
        token: Hub token with write access. The result goes under its owner.
        ignore_str: Comma-separated layer names / ``re:`` patterns to skip.
        model_class_name: transformers class used to load the model. A blank
            value falls back to ``AutoModelForCausalLM``.

    Returns:
        Markdown text for the output panel.
    """
    import traceback

    # Gradio/API callers may send None or stray whitespace (e.g. a pasted token).
    model_id = (model_id or "").strip()
    token = (token or "").strip()
    model_class_name = (model_class_name or "").strip() or "AutoModelForCausalLM"

    # The Markdown replies are built flush-left on purpose: lines indented by
    # 4+ spaces would be rendered as a code block instead of headings and lists.
    if not token or not model_id:
        return (
            "### Invalid input 🐞\n\n"
            "Please provide both a token and model_id.\n"
        )

    try:
        ignore_list = parse_ignore_list(ignore_str or "")

        # Set up API with user's token
        api = HfApi(token=token)

        print("Processing model:", model_id)
        print("Ignore list:", ignore_list)
        print("Model class:", model_class_name)

        work_dir = "quantized_models"
        os.makedirs(work_dir, exist_ok=True)

        quantized_path, errors = create_quantized_model(
            model_id,
            work_dir,
            ignore_list,
            model_class_name
        )

        commit_info = push_to_hub(
            api,
            model_id,
            quantized_path,
            token,
            ignore_list,
            model_class_name
        )

        # Backticks keep patterns such as `re:.*lm_head` from being parsed as emphasis.
        response = (
            "### Success 🔥\n\n"
            "Your model has been successfully quantized to FP8 and uploaded to a new repository:\n\n"
            f"[{commit_info.repo_url}]({commit_info.repo_url})\n\n"
            "Configuration:\n\n"
            f"- Ignored layers: `{ignore_list}`\n"
            f"- Model class: `{model_class_name}`\n\n"
            "You can use this model directly with the transformers library!\n"
        )

        if errors:
            response += "\nWarnings during quantization:\n\n"
            response += "\n".join(f"- Warning for {filename}: {e}" for filename, e in errors)

        return response

    except Exception as e:
        # Keep the full traceback in the Space logs; the user only sees the message.
        traceback.print_exc()
        return (
            "### Error 😢\n\n"
            "An error occurred during processing:\n\n"
            f"```\n{e}\n```\n"
        )

# Gradio Interface
# Markdown rendered at the top of the page by gr.Markdown. It describes the
# inputs of `run`; keep its examples/defaults in sync with the widgets below
# (e.g. the model_class_name default "AutoModelForCausalLM").
DESCRIPTION = """
# Convert any model to FP8 using LLM Compressor

This space will quantize your model to FP8 format using LLM Compressor and create a new model repository under your account.

The steps are:
1. Paste your HuggingFace token (from hf.co/settings/tokens) - needs write access
2. Enter the model ID you want to quantize
3. (Optional) Customize ignored layers and model class
4. Click "Submit"
5. You'll get a link to your new quantized model repository on your profile! 🚀

## Advanced Options:
- **Ignore List**: Comma-separated list of layer patterns to ignore during quantization. Examples:
  - Llama: `lm_head`
  - Phi3v: `re:.*lm_head,re:model.vision_embed_tokens.*`
  - Llama Vision: `re:.*lm_head,re:multi_modal_projector.*,re:vision_model.*`
- **Model Class**: Specific model class from transformers (default: AutoModelForCausalLM). Examples:
  - `AutoModelForCausalLM`
  - `MllamaForConditionalGeneration`
  - `LlavaForConditionalGeneration`

Note: 
- Processing may take several minutes depending on the model size
- The quantized model will be created as a new public repository under your account
- Your token needs write access to create the new repository
"""

# Page title passed to gr.Blocks(title=...).
title = "FP8 Quantization with LLM Compressor"

# Page layout: inputs on the left, Markdown result from `run` on the right.
with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model_id = gr.Text(
                max_lines=1,
                label="model_id",
                placeholder="huggingface/model-name"
            )
            # Masked so the write token is not displayed on screen.
            token = gr.Text(
                max_lines=1,
                label="your_hf_token (requires write access)",
                placeholder="hf_...",
                type="password",
            )
            ignore_str = gr.Text(
                max_lines=1,
                label="ignore_list (comma-separated)",
                placeholder="re:.*lm_head,re:vision_model.*",
                value="re:.*lm_head"
            )
            model_class_name = gr.Text(
                max_lines=1,
                label="model_class_name (optional)",
                placeholder="AutoModelForCausalLM",
                value="AutoModelForCausalLM"
            )

            with gr.Row():
                clean = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")

        with gr.Column():
            output = gr.Markdown()

    # A ClearButton created without components clears nothing. Register the
    # per-request fields and the result panel now that `output` exists.
    # ignore_str/model_class_name are deliberately left untouched so their
    # values survive a clear.
    clean.add([model_id, token, output])

    # concurrency_limit=1: only one quantization job runs at a time.
    submit.click(
        run,
        inputs=[model_id, token, ignore_str, model_class_name],
        outputs=output,
        concurrency_limit=1
    )

demo.queue(max_size=10).launch(show_api=True)