File size: 3,593 Bytes
3c3eabb
 
5204c67
3c3eabb
 
f026dba
ac2cf21
 
 
5204c67
 
ac2cf21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5204c67
ac2cf21
 
 
 
 
 
5204c67
ac2cf21
edce2eb
ac2cf21
 
 
 
5204c67
ac2cf21
5204c67
 
edce2eb
5204c67
ac2cf21
 
 
 
 
 
5204c67
ac2cf21
 
 
5204c67
ac2cf21
 
 
 
5204c67
e2ed84a
5204c67
 
 
ade3ffb
ac2cf21
 
 
ade3ffb
 
5204c67
 
ade3ffb
 
 
 
 
 
ac2cf21
 
 
3c3eabb
 
5204c67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr

from transformers import GPT2Tokenizer, AutoModelForCausalLM
import numpy as np

MODEL_NAME = "gpt2"

if __name__ == "__main__":
    # Load the tokenizer and the causal language model that MODEL_NAME points to.
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    # If the tokenizer has no padding token, reuse the end-of-sequence id for
    # padding, on both the tokenizer and the model config.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = model.config.eos_token_id

    # Colour used in the UI for each probability label.
    label_to_color = {
        "p >= 10%": "green",
        "p >= 1%": "yellow",
        "p < 1%": "red",
    }

    # (minimum probability, label) pairs: a token receives the label of the
    # first pair whose minimum it reaches, so the list MUST stay sorted by
    # descending probability.
    probs_to_label = [
        (0.1, "p >= 10%"),
        (0.01, "p >= 1%"),
        (1e-20, "p < 1%"),
    ]


    def get_tokens_and_labels(prompt):
        """
        Sample a continuation of `prompt` and label every generated token by its probability.

        Args:
            prompt: Text that the model should continue.

        Returns:
            A list of `(text, label)` tuples. The first tuple holds the decoded prompt with a
            `None` label (no color). Each following tuple holds one generated token and the
            first label in `probs_to_label` whose minimum probability the token reaches, or
            `None` when the token is below every threshold.

        Raises:
            ValueError: If a token probability is outside [0, 1] (this also catches NaN).
        """
        inputs = tokenizer([prompt], return_tensors="pt")
        outputs = model.generate(
            **inputs, max_new_tokens=50, return_dict_in_generate=True, output_scores=True, do_sample=True
        )
        # Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
        transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
        # Exponentiate on the tensor itself and convert to plain Python floats, so the
        # comparisons below do not depend on NumPy operating on a framework tensor.
        # Only one prompt is submitted, hence batch row 0.
        transition_proba = transition_scores[0].exp().tolist()
        # We only have scores for the generated tokens, so pop out the prompt tokens
        input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
        generated_tokens = outputs.sequences[0, input_length:]

        # Initialize the highlighted output with the prompt, which will have no color label
        highlighted_out = [(tokenizer.decode(inputs.input_ids[0]), None)]
        # Get the (decoded_token, label) pairs for the generated tokens
        for token, proba in zip(generated_tokens, transition_proba):
            # Explicit check instead of `assert`, which is stripped when Python runs with -O.
            if not 0.0 <= proba <= 1.0:
                raise ValueError(f"Token probability {proba!r} is outside the [0, 1] range")
            # `probs_to_label` is sorted by descending threshold, so the first match wins.
            this_label = next(
                (label for min_proba, label in probs_to_label if proba >= min_proba),
                None,
            )
            highlighted_out.append((tokenizer.decode(token), this_label))

        return highlighted_out


    # Build the Gradio UI: a prompt box and a button on the left, and the
    # color-coded generation on the right.
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # 🌈 Color Coded Text Generation 🌈
            This is a demo of how you can obtain the probabilities of each generated token, and use them to
            color code the model output.
            Feel free to clone this demo and modify it to your needs 🤗
            Internally, it relies on [`compute_transition_scores`](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores),
            which was added in `transformers` v4.26.0.
            """
        )

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", lines=3, value="Today is")
                button = gr.Button(f"Generate with {MODEL_NAME}, using sampling!")
            with gr.Column():
                # NOTE(review): `.style()` is believed to be deprecated/removed in newer
                # Gradio releases in favour of a `color_map=` constructor argument --
                # confirm against the Gradio version this demo is pinned to.
                highlighted_text = gr.HighlightedText(
                    label="Highlighted generation",
                    combine_adjacent=True,
                    show_legend=True,
                ).style(color_map=label_to_color)

        # Clicking the button sends the prompt text to `get_tokens_and_labels` and
        # renders the returned (text, label) pairs in the highlighted output.
        button.click(get_tokens_and_labels, inputs=prompt, outputs=highlighted_text)

# Launch the demo only when this file is run as a script; `demo` is created
# under the same guard earlier in the file.
if __name__ == "__main__":
    demo.launch()