import os, json, random
import torch
import gradio as gr
import spaces
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyreft
import pyvene as pv
from threading import Thread
from typing import Iterator
import torch.nn.functional as F
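
# Gradio demo: chat with Gemma-2-2B-it while steering its responses toward a
# chosen topic. Steering vectors are added to the residual stream at layer 20,
# taken from a pre-trained dictionary or generated on the fly from free text.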

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 128  # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096

css = """
#alert-message textarea {
    background-color: #e8f4ff;
    border: 1px solid #cce5ff;
    color: #084298;
    font-size: 1.1em;
    padding: 12px;
    border-radius: 4px;
    font-weight: 500;
}
"""

def load_jsonl(jsonl_path):
    jsonl_data = []
    with open(jsonl_path, 'r') as f:
        for line in f:
            data = json.loads(line)
            jsonl_data.append(data)
    return jsonl_data
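# Each metadata line is expected to be a JSON object with at least a "concept"
# field, e.g. {"concept": "time travel", ...}; only "concept" is used below.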

class Steer(pv.SourcelessIntervention):
    """Steer model via activation addition"""
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["latent_dim"], bias=False)
        self.subspace_generator = kwargs["subspace_generator"]
    def forward(self, base, source=None, subspaces=None):
        if subspaces is None:
            return base
        if subspaces.get("subspace_gen_inputs") is not None:
            # Call the subspace generator to produce the steering direction
            # on the fly from the user-provided text.
            raw_steering_vec = self.subspace_generator(
                subspaces["subspace_gen_inputs"]["input_ids"],
                subspaces["subspace_gen_inputs"]["attention_mask"],
            )[0]
            steering_vec = torch.tensor(subspaces["mag"]) * \
                raw_steering_vec.unsqueeze(dim=0)
        else:
            # Look up a pre-trained steering direction from the dictionary.
            steering_vec = torch.tensor(subspaces["mag"]) * \
                self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
        return base + steering_vec
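
# Steer supports two modes: "subspace_gen_inputs" asks the subspace generator
# (defined below) to produce a fresh direction from free text, while "idx"
# selects a pre-trained direction stored as a row of `proj.weight`. Either way
# the direction, scaled by "mag", is added to the residual stream activations.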

class RegressionWrapper(torch.nn.Module):
    """Maps text to a unit-norm steering direction: the base LM's last-token
    hidden state is projected through a linear regression head."""
    def __init__(self, base_model, hidden_size, output_dim):
        super().__init__()
        self.base_model = base_model
        self.regression_head = torch.nn.Linear(hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        outputs = self.base_model.model(
            input_ids=input_ids, 
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        # Summarize the input with the final-layer hidden state of the last
        # token, then project it into a steering direction.
        last_hiddens = outputs.hidden_states[-1]
        last_token_representations = last_hiddens[:, -1]
        preds = self.regression_head(last_token_representations)
        preds = F.normalize(preds, p=2, dim=-1)
        return preds

# Check GPU: the model and steering pipeline are only loaded when CUDA is available.
if not torch.cuda.is_available():
    print("Warning: no GPU found; the model will not be loaded and chat is disabled.")

# Load model & dictionary
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Download dictionary
    weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
    meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
    params = torch.load(weight_path).cuda()
    md = load_jsonl(meta_path)

    concept_list = [item["concept"] for item in md]
    # Reindex the concepts: one concept is missing from the dictionary, so
    # positions in `md` do not line up with the original concept ids.
    concept_id_map = {item["concept"]: i for i, item in enumerate(md)}

    # Load the subspace generator (a Gemma-2-2B backbone + regression head).
    base_tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b", model_max_length=512)
    config = AutoConfig.from_pretrained("google/gemma-2-2b")
    base_model = AutoModelForCausalLM.from_config(config)
    
    subspace_generator_weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res-generator", filename="l20/weight.pt")
    hidden_size = base_model.config.hidden_size
    subspace_generator = RegressionWrapper(
        base_model, hidden_size, hidden_size).bfloat16().to("cuda")
    subspace_generator.load_state_dict(torch.load(subspace_generator_weight_path))
    print(f"Loaded subspace generator weights from {subspace_generator_weight_path}")
    _ = subspace_generator.eval()

    steer = Steer(
        embed_dim=params.shape[0], latent_dim=params.shape[1], 
        subspace_generator=subspace_generator)
    steer.proj.weight.data = params.float()

    pv_model = pv.IntervenableModel({
        "component": "model.layers[20].output",
        "intervention": steer}, model=model)

# Stop decoding at EOS and, assuming Gemma's chat format, at "<end_of_turn>".
terminators = (
    [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    if tokenizer else []
)

@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    subspaces_list: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:

    # Bail out early if the model was never loaded (no GPU at startup).
    if pv_model is None or tokenizer is None:
        yield "Model not loaded: a GPU is required to run this demo."
        return

    # Keep only the 4 most recent messages (two user/assistant exchanges).
    start_idx = max(0, len(chat_history) - 4)
    recent_history = chat_history[start_idx:]

    # build list of messages
    messages = []
    for rh in recent_history:
        messages.append({"role": rh["role"], "content": rh["content"]})
    messages.append({"role": "user", "content": message})

    input_ids = torch.tensor([tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True)]).cuda()

    # Trim the prompt if needed and surface the truncation to the user.
    prefix = ""
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        prefix = "[Truncated prior text]\n"
        yield prefix

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    print("Steering subspaces:", subspaces_list)
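    # With unit_locations=None and intervene_on_prompt=True, the intervention
    # is applied over the prompt and then throughout decoding, rather than at
    # specific token positions.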
    generate_kwargs = {
        "base": {"input_ids": input_ids},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": [
            {
                "idx": int(subspaces_list[0]["idx"]),
                "mag": int(subspaces_list[0]["internal_mag"]),
                "subspace_gen_inputs": base_tokenizer(subspaces_list[0]["subspace_gen_text"], return_tensors="pt").to("cuda") \
                    if subspaces_list[0]["subspace_gen_text"] is not None else None
            }
        ] if subspaces_list else None,
        "streamer": streamer,
        "do_sample": True
    }

    # Run generation on a background thread so tokens can stream as they arrive.
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()

    # Stream cumulative text; keep the truncation notice, if any, at the front.
    partial_text = [prefix]
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)

def filter_concepts(search_text: str):
    # Case-insensitive substring match; results are capped at 500 to keep the
    # dropdown responsive.
    if not search_text.strip():
        return concept_list[:500]
    filtered = [c for c in concept_list if search_text.lower() in c.lower()]
    return filtered[:500]

def add_concept_to_list(selected_concept, user_slider_val, current_list):
    if not selected_concept:
        return current_list

    selected_concept_text = None
    if selected_concept.startswith("[New] "):
        # A brand-new topic: keep the raw text so the subspace generator can
        # produce a steering vector for it; `idx` is unused in this mode.
        selected_concept_text = selected_concept.removeprefix("[New] ")
        idx = 0
    else:
        idx = concept_id_map[selected_concept]
    internal_mag = user_slider_val * 50
    new_entry = {
        "text": selected_concept,
        "idx": idx,
        "display_mag": user_slider_val,
        "internal_mag": internal_mag,
        "subspace_gen_text": selected_concept_text
    }
    # Only one steering concept is active at a time, so replace the list.
    current_list = [new_entry]
    return current_list
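
# Note: the user-facing slider value (-5 to 5) is scaled by 50 into the
# internal magnitude that multiplies the steering vector (-250 to 250).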

def update_dropdown_choices(search_text):
    filtered = filter_concepts(search_text)
    if not filtered:
        # No match: offer the search text itself as a new, auto-generated topic.
        return gr.update(
            choices=[f"[New] {search_text}"],
            value=f"[New] {search_text}",
            interactive=True,
        ), gr.Textbox(
            label="No matching existing topics were found!",
            value="Good news! Based on the topic you provided, we will automatically generate a steering vector. Try it out by starting a chat!",
            lines=3,
            interactive=False,
            visible=True,
            elem_id="alert-message",
        )
    # Automatically select the first matching concept
    return gr.update(
        choices=filtered,
        value=filtered[0],  # Select the first match
        interactive=True, visible=True
    ), gr.Textbox(visible=False)

with gr.Blocks(css=css, fill_height=True) as demo:
    # Steering selection state; starts empty until the user picks a topic.
    selected_subspaces = gr.State([])
    
    with gr.Row(min_height=300):
        # Left side: bigger chat area
        with gr.Column(scale=7):
            chat_interface = gr.ChatInterface(
                fn=generate,
                title="Chat with a Topic Steering Model",
                description="""Choose a topic you want the model to discuss on the right →\n\nWe intervene on Gemma-2-2B-it by adding steering vectors to the residual stream at layer 20. You can also try our **conditioned steering** model [here](https://huggingface.co/spaces/pyvene/AxBench-ReFT-cr1-16K).""",
                type="messages",
                additional_inputs=[selected_subspaces],
            )
        
        # Right side: concept management
        with gr.Column(scale=3):
            gr.Markdown("# Steer model responses")
            gr.Markdown("Search and then select a topic you want the model to discuss. The closest match will be automatically selected. If there is no match, a finetuned Gemma-2-2B model auto-steers for you!")
            # Concept Search and Selection
            with gr.Group():
                search_box = gr.Textbox(
                    label="Search topics to steer",
                    placeholder="Try: 'time travel'",
                    lines=2,
                )
                msg = gr.TextArea(visible=False)
                concept_dropdown = gr.Dropdown(
                    label="Select a topic to steer the model (Click to see more!)",
                    interactive=True,
                    allow_custom_value=False,
                )
                concept_magnitude = gr.Slider(
                    label="Steering intensity",
                    minimum=-5,
                    maximum=5,
                    step=0.1,
                    value=3,
                )
            
    # Wire up events
    # When search box changes, update dropdown AND trigger concept selection
    search_box.input(
        update_dropdown_choices,
        [search_box],
        [concept_dropdown, msg]
    ).then(  # Chain the events to automatically add the concept
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )

    # `.change` fires for both user selections and programmatic updates (e.g.
    # when the search box repopulates the dropdown), so one handler suffices.
    concept_dropdown.change(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )
    
    concept_magnitude.input(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )

demo.launch(share=True)