"""Beam Search Visualizer: a Gradio demo that renders the beam search decoding
tree of GPT-2, with per-step and cumulative scores for each candidate token."""

from dataclasses import dataclass
from typing import Dict, Optional

import gradio as gr
import numpy as np
import pandas as pd
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
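# GPT-2 is small enough to keep this demo responsive; its tokenizer also supplies
# the `eos_token_id` used later to detect finished beams.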

print("Loading finished.")

print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

STYLE = """
.custom-container {
    display: grid;
    align-items: center;
    margin: 0!important;
    overflow-y: hidden;
}
.prose ul ul {
    font-size: 10px!important;
}
.prose li {
    margin-bottom: 0!important;
}
.prose table {
    margin-bottom: 0!important;
}
.prose td, .prose th {
    padding: 0 2px;
    white-space: nowrap;
}
.tree {
    padding: 0;
    margin: 0!important;
    box-sizing: border-box;
    font-size: 10px;
    width: 100%;
    height: auto;
    text-align: center;
    display: inline-block;
}
#root {
    display: inline-grid!important;
    width:auto!important;
    min-width: 220px;
}
.tree ul {
    padding-left: 20px;
    position: relative;
    transition: all 0.5s ease 0s;
    display: flex;
    flex-direction: column;
    gap: 10px;
    margin: 0px !important;
}
.tree li {
    display: flex;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-left: 20px;
    transition: all 0.5s ease 0s;
    flex-direction: row;
    justify-content: start;
    align-items: center;
}
.tree li::before, .tree li::after {
    content: "";
    position: absolute;
    left: 0px;
    border-left: 1px solid var(--body-text-color);
    width: 20px;
}
.tree li::before {
    top: 0;
    height:50%;
}
.tree li::after {
    top: 50%;
    height: 55%;
    bottom: auto;
    border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
    border-bottom: 1px solid var(--body-text-color);
    border-radius: 0 0 0 5px;
}
.tree li:first-child::after {
    border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: "";
    position: absolute;
    left: 0;
    top: 50%;
    border-top: 1px solid var(--body-text-color);
    width: 20px;
    height: 0;
}
.tree ul:has(> li:only-child)::before {
    width:40px;
}
.child::before {
    border-right: 2px solid var(--body-text-color);
    border-bottom: 2px solid var(--body-text-color);
    content: "";
    position: absolute;
    width: 10px;
    left: 8px;
    height: 10px;
    top: 50%;
    margin-top: -5px;
    transform: rotate(315deg);
}
.tree li a {
    border: 1px solid var(--body-text-color);
    padding: 5px;
    border-radius: 5px;
    text-decoration-line: none;
    transition: .5s;
    display: flex;
    align-items: center;
    justify-content: space-between;
    overflow: hidden;
}
.tree li a span {
    padding: 5px;
    font-size: 12px;
    letter-spacing: 1px;
    font-weight: 500;
}
/* Hover section */
.tree li a:hover, .tree li a:hover+ul li a {
    background: var(--primary-500);
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before, .tree li a:hover+ul a::before {
    border-color: var(--primary-500);
}
.chosen-token {
    background-color: var(--primary-400);
}
.chosen-token td, .chosen-token tr {
    color: black!important;
}
.end-of-text {
    width:auto!important;
}
.nonfinal {
    width:280px;
    min-width: 280px;
}
.selected-sequence {
    background-color: var(--secondary-500);
}
.nonselected-sequence {
    background-color: var(--primary-500);
}
.nopadding {
    padding-left: 0;
}
"""

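# The CSS above renders the beam tree: connector lines are drawn with
# ::before/::after pseudo-elements on each <li>, while the chosen-token and
# selected-sequence classes highlight the candidates that beam search keeps.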

def clean(s):
    """Escape newlines and tabs so each token renders on a single line in the HTML tree."""
    return s.replace("\n", r"\n").replace("\t", r"\t").strip()


def generate_markdown_table(
    scores, previous_cumul_score, score_divider, top_k=4, chosen_tokens=None
):
    """Build an HTML table of the top_k candidate tokens at one step, showing each
    token's step score and its length-normalized total score."""
    markdown_table = """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Step score</b></th>
            <th><b>Total score</b></th>
        </tr>"""
    for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
        token = tokenizer.decode([token_idx])
        item_class = ""
        if chosen_tokens and token in chosen_tokens:
            item_class = "chosen-token"
        markdown_table += f"""
        <tr class="{item_class}">
            <td>{clean(token)}</td>
            <td>{scores[token_idx]:.4f}</td>
            <td>{(scores[token_idx] + previous_cumul_score)/score_divider:.4f}</td>
        </tr>"""
    markdown_table += """
    </table>"""
    return markdown_table
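
# Illustrative example for generate_markdown_table (made-up numbers): with
# previous_cumul_score=-1.2 and score_divider=2.0, a candidate whose step score
# is -0.8 is displayed with a total score of (-0.8 + -1.2) / 2.0 = -1.0.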


def generate_nodes(node, step):
    """Recursively generate HTML for the tree nodes."""
    token = tokenizer.decode([node.current_token_ix])

    if node.is_final:
        if node.is_selected_sequence:
            selected_class = "selected-sequence"
        else:
            selected_class = "nonselected-sequence"
        return f"<li> <a class='end-of-text child {selected_class}'> <span> <b>{clean(token)}</b> <br>Total score: {node.total_score:.2f}</span> </a> </li>"

    html_content = (
        f"<li> <a class='nonfinal child'> <span> <b>{clean(token)}</b> </span>"
    )
    if node.table is not None:
        html_content += node.table
    html_content += "</a>"

    if node.children:
        html_content += "<ul> "
        for token_ix, subnode in node.children.items():
            html_content += generate_nodes(subnode, step=step + 1)
        html_content += "</ul>"
    html_content += "</li>"

    return html_content


def generate_html(start_sentence, original_tree):
    html_output = f"""<div class="custom-container">
				<div class="tree">
                <ul> <li> <a id='root' class="nopadding"> <span> <b>{start_sentence}</b> </span> {original_tree.table} </a>"""
    html_output += "<ul> "
    for subnode in original_tree.children.values():
        html_output += generate_nodes(subnode, step=1)
    html_output += "</ul>"
    html_output += """
        </li> </ul>
        </div>
    </div>
    """
    return html_output


@dataclass
class BeamNode:
    current_token_ix: Optional[int]  # None only for the root node
    cumulative_score: float
    children_score_divider: float
    table: Optional[str]
    current_sequence: str
    children: Dict[int, "BeamNode"]
    total_score: Optional[float]
    is_final: bool
    is_selected_sequence: bool
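
# Scoring convention used below: `cumulative_score` is the sum of per-token
# log-probabilities along a path, and `total_score` normalizes it as
# cumulative_score / (length ** length_penalty). E.g. with length_penalty=1.0,
# a 3-token path with cumulative_score=-3.0 has total_score -3.0 / 3 = -1.0.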


def generate_beams(
    n_beams, start_sentence, scores, length_penalty, decoded_sequences,
    beam_indexes_source,  # currently unused
):
    original_tree = BeamNode(
        cumulative_score=0,
        current_token_ix=None,
        table=None,
        current_sequence=start_sentence,
        children={},
        children_score_divider=(1 ** length_penalty),
        total_score=None,
        is_final=False,
        is_selected_sequence=False,
    )
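    # All beams initially share the same root node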
    beam_trees = [original_tree] * n_beams
    generation_length = len(scores)
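    # Each element of `scores` (from output_scores=True) holds, for one generation
    # step, a row of log-probability scores over the vocabulary for each beam.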

    for step, step_scores in enumerate(scores):

        # Gather all possible descendants for each beam
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_sequence,
            top_tokens,
            token_scores,
        ) = ([], [], [], [], [], [])

        score_idx = 0
        for beam_ix in range(len(beam_trees)):
            current_beam = beam_trees[beam_ix]

            # skip if the beam is already final
            if current_beam.is_final:
                continue
                
            # Get the indexes of the top n_beams continuations of this beam
            current_top_token_indexes = list(
                np.array(step_scores[score_idx].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            token_scores += list(np.array(step_scores[score_idx][current_top_token_indexes]))
            top_cumulative_scores += list(
                np.array(step_scores[score_idx][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_sequence += [beam_trees[beam_ix].current_sequence] * n_beams
            top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]
            score_idx += 1

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_sequence": current_sequence,
                "token": top_tokens,
                "token_score": token_scores,
            }
        )
        maxes = top_df.groupby(["token_index", "current_sequence"])[
            "cumulative_score"
        ].idxmax()

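        # If several beams propose the same token on top of the same prefix, keep
        # only the highest-scoring duplicate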
        top_df = top_df.loc[maxes]

        # Sort all top probabilities and keep top n_beams * 2 (* 2 because each beam may end this iteration, and we
        # want to keep at least `n_beams` beams alive)
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams * 2
        ]
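        # Walk the candidates in score order until n_beams of them are unfinished
        # (a finished candidate ends its beam, freeing the slot); on the last step,
        # keep exactly n_beams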
        beams_to_keep = 0
        unfinished_beams = 0
        for _, row in top_df_selected.iterrows():
            beams_to_keep += 1
            current_token_choice_ix = row["token_index"]
            is_final = step == generation_length - 1 or current_token_choice_ix == tokenizer.eos_token_id
            if not is_final:
                unfinished_beams += 1
            if unfinished_beams >= n_beams:
                break
            if step == generation_length - 1 and beams_to_keep == n_beams:
                break
        top_df_selected_filtered = top_df_selected.iloc[:beams_to_keep]

        # Write the scores table in each beam tree
        score_idx = 0
        for beam_ix in range(len(beam_trees)):
            current_beam = beam_trees[beam_ix]
            if current_beam.table is None:
                selected_tokens = top_df_selected_filtered.loc[
                    top_df_selected_filtered["current_sequence"] == current_beam.current_sequence
                ]
                markdown_table = generate_markdown_table(
                    step_scores[score_idx, :],
                    current_beam.cumulative_score,
                    current_beam.children_score_divider,
                    chosen_tokens=list(selected_tokens["token"].values),
                )
                beam_trees[beam_ix].table = markdown_table
            if not current_beam.is_final:
                score_idx = min(score_idx + 1, n_beams - 1)

        # Add new children to each beam
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for _, row in top_df_selected_filtered.iterrows():
            # Update the source tree
            source_beam_ix = int(row["beam_index"])
            current_token_choice_ix = row["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            token_score = row["token_score"]

            cumulative_score = cumulative_scores[source_beam_ix] + token_score
            current_sequence = (
                beam_trees[source_beam_ix].current_sequence + current_token_choice
            )
            is_final = step == generation_length - 1 or current_token_choice_ix == tokenizer.eos_token_id
            beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
                current_token_ix=current_token_choice_ix,
                table=None,
                children={},
                current_sequence=current_sequence,
                cumulative_score=cumulative_score,
                total_score=cumulative_score / ((step + 1) ** length_penalty),
                children_score_divider=((step + 2) ** length_penalty),
                is_final=is_final,
                is_selected_sequence=(
                    current_sequence.replace("<|endoftext|>", "")
                    in [el.replace("<|endoftext|>", "") for el in decoded_sequences]
                ),
            )

        # Reorder beams by descending cumulative score, so that beam 0 always has the highest score
        beam_trees = [
            beam_trees[int(top_df_selected_filtered.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(beams_to_keep)
        ]

        # Advance all beams by one token
        for beam_ix in range(beams_to_keep):
            current_token_choice_ix = top_df_selected_filtered.iloc[beam_ix]["token_index"]
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]

        print(f"Step {step}, beams kept: {beams_to_keep}")

    return original_tree

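# On Hugging Face ZeroGPU Spaces, `@spaces.GPU` allocates a GPU for the duration
# of each call to the decorated function.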
@spaces.GPU
def get_beam_search_html(
    input_text, number_steps, number_beams, length_penalty, num_return_sequences
):
    inputs = tokenizer([input_text], return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=num_return_sequences,
        return_dict_in_generate=True,
        length_penalty=length_penalty,
        output_scores=True,
        do_sample=False,
    )
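
    # `outputs.sequences_scores` holds the final length-penalized score of each
    # returned sequence, and `outputs.scores` the per-step beam score tensors
    # consumed by `generate_beams` below.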
    markdown = "Finished sequences are those that end in an `<|endoftext|>` token or that reach the end of generation."
    markdown += "\n\nThey are ranked by their scores, as given by the formula `score = cumulative_score / (output_length ** length_penalty)`.\n\n"
    markdown += "Only the `num_return_sequences` top-scoring sequences are returned: in the tree they are highlighted in **<span style='color:var(--secondary-500)!important'>blue</span>**."
    markdown += " The non-selected sequences are also shown in the tree, highlighted in **<span style='color:var(--primary-500)!important'>yellow</span>**."
    markdown += "\n#### <span style='color:var(--secondary-500)!important'>Output sequences:</span>"
    # Returned sequences are all padded to the same length, so they can be batch-decoded
    decoded_sequences = tokenizer.batch_decode(outputs.sequences)
    for i, sequence in enumerate(decoded_sequences):
        markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence)}`"

    original_tree = generate_beams(
        number_beams,
        input_text,
        outputs.scores,
        length_penalty,
        decoded_sequences,
        outputs.beam_indices,
    )
    html = generate_html(input_text, original_tree)
    return html, markdown


def change_num_return_sequences(n_beams):
    return gr.Slider(
        label="Number of sequences", minimum=1, maximum=n_beams, step=1, value=n_beams
    )


with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.yellow,
        secondary_hue=gr.themes.colors.blue,
    ),
    css=STYLE,
) as demo:
    gr.Markdown(
        """# <span style='color:var(--primary-500)!important'>Beam Search Visualizer</span>

Play with the parameters below to understand how beam search decoding works!

#### <span style='color:var(--primary-500)!important'>Parameters:</span>
- **Sentence to decode from** (`inputs`): the input sequence to your decoder.
- **Number of steps** (`max_new_tokens`): the number of tokens to generate.
- **Number of beams** (`num_beams`): the number of beams to use.
- **Length penalty** (`length_penalty`): the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
This parameter does not change the beam search paths themselves; it only influences the final ranking of sequences towards longer or shorter ones.
- **Number of return sequences** (`num_return_sequences`): the number of sequences to be returned at the end of generation. Should be `<= num_beams`.
"""
    )
    text = gr.Textbox(
        label="Sentence to decode from",
        value="Conclusion: thanks a lot. That's all for today",
    )
    with gr.Row():
        n_steps = gr.Slider(
            label="Number of steps", minimum=1, maximum=12, step=1, value=5
        )
        n_beams = gr.Slider(
            label="Number of beams", minimum=1, maximum=4, step=1, value=4
        )
        length_penalty = gr.Slider(
            label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
        )
        num_return_sequences = gr.Slider(
            label="Number of return sequences", minimum=1, maximum=4, step=1, value=3
        )

    n_beams.change(
        fn=change_num_return_sequences, inputs=n_beams, outputs=num_return_sequences
    )
    button = gr.Button()
    out_html = gr.Markdown()
    out_markdown = gr.Markdown()
    button.click(
        get_beam_search_html,
        inputs=[text, n_steps, n_beams, length_penalty, num_return_sequences],
        outputs=[out_html, out_markdown],
    )

demo.launch()