File size: 4,758 Bytes
3c3eabb
 
7a24b1b
3c3eabb
92dde49
 
3c3eabb
d51d76f
ac2cf21
f998e8f
ac2cf21
 
7a24b1b
5204c67
ac2cf21
 
 
b5b0c27
ac2cf21
 
 
 
 
 
 
 
 
 
 
 
 
 
5204c67
ac2cf21
 
 
 
 
805081b
 
 
 
 
 
 
ac2cf21
6ef22f3
ac2cf21
edce2eb
ac2cf21
 
 
 
5204c67
ac2cf21
5204c67
 
edce2eb
5204c67
ac2cf21
 
 
 
 
 
5204c67
ac2cf21
 
 
5204c67
ac2cf21
 
 
 
075c5ca
6d80e45
075c5ca
2f4db23
4d5cbc9
a1aa766
de15f45
4d5cbc9
 
 
2f4db23
6d80e45
 
ac2cf21
 
 
ade3ffb
 
d7755a4
a1aa766
ade3ffb
 
 
 
 
 
ac2cf21
 
 
3c3eabb
 
5204c67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr

from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

MODEL_NAME = "gpt2"

# Upper bound on the number of tokens generated per request. The demo description
# below interpolates this value so the UI text cannot drift from the actual setting.
MAX_NEW_TOKENS = 50

# EBNF grammar constraining the generation; opened relative to the current working directory.
GRAMMAR_PATH = "json_minimal.ebnf"

# Color-coding thresholds: a token receives the label of the FIRST entry whose minimum
# probability it reaches, so entries must stay sorted by descending threshold. The final
# 0.0 threshold makes "p < 1%" cover every remaining probability, so no generated token
# is left without a label.
PROBS_TO_LABEL = (
    (0.1, "p >= 10%"),
    (0.01, "p >= 1%"),
    (0.0, "p < 1%"),
)

# Legend colors for gr.HighlightedText; keys must match the labels in PROBS_TO_LABEL.
LABEL_TO_COLOR = {
    "p >= 10%": "green",
    "p >= 1%": "yellow",
    "p < 1%": "red",
}


def label_for_probability(proba, probs_to_label=PROBS_TO_LABEL):
    """Return the color label for a generated token's probability.

    Args:
        proba: Probability of the generated token (expected to lie in [0, 1]).
        probs_to_label: ``(min_proba, label)`` pairs sorted by descending
            ``min_proba``.

    Returns:
        The label of the first pair with ``proba >= min_proba``, or ``None``
        when no threshold is reached (e.g. a NaN or negative ``proba``).
    """
    for min_proba, label in probs_to_label:
        if proba >= min_proba:
            return label
    return None


if __name__ == "__main__":
    # Define the model and its tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token_id is None:
        # No pad token defined: reuse EOS for padding on both tokenizer and model config.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = model.config.eos_token_id
    # NOTE(review): to_bettertransformer() relies on the `optimum` package and is
    # deprecated in recent transformers releases — confirm it is still wanted.
    model.to_bettertransformer()

    # Read the grammar text once at startup (fails fast if the file is missing).
    # The constraint/processor objects are still rebuilt for every request below.
    with open(GRAMMAR_PATH, "r", encoding="utf-8") as grammar_file:
        grammar_str = grammar_file.read()

    def get_tokens_and_labels(prompt):
        """Generate grammar-constrained text for ``prompt`` and label each token.

        Returns a list of ``(decoded_text, label)`` tuples for
        ``gr.HighlightedText``: the prompt first (label ``None``), followed by
        one entry per generated token, labeled by its probability.
        """
        inputs = tokenizer([prompt], return_tensors="pt")

        # Create a fresh grammar constraint and logits processor for each call
        # (presumably because they track incremental parse state across decoding
        # steps — TODO confirm against transformers-cfg).
        grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
        grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            repetition_penalty=1,
            return_dict_in_generate=True,
            output_scores=True,
            logits_processor=[grammar_processor],
        )
        # normalize_logits=True is required to obtain normalized log-probabilities
        # (i.e. sum(p) = 1 at each step).
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        # Move to host memory before handing the tensor to NumPy.
        transition_proba = np.exp(transition_scores.cpu().numpy())
        # Scores exist only for generated tokens, so skip the prompt tokens.
        input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
        generated_tokens = outputs.sequences[:, input_length:]

        # inputs.input_ids holds a single row (one prompt), so this yields one
        # entry containing the whole prompt, without a color label.
        highlighted_out = [(tokenizer.decode(token), None) for token in inputs.input_ids]
        for token, proba in zip(generated_tokens[0], transition_proba[0]):
            highlighted_out.append((tokenizer.decode(token), label_for_probability(proba)))

        return highlighted_out

    demo = gr.Blocks()
    with demo:
        gr.Markdown(
            f"""
            # 👻 Transformers-CFG JSON Demo
            This is a demo of how you can constrain the output of a GPT-2 model to be a **valid** JSON string (**up to truncation**).
            Here we use a simple JSON grammar to constrain the output of the model.
            The grammar is defined in `{GRAMMAR_PATH}` and is written in the **Extended Backus-Naur Form (EBNF)**.

            Internally, it relies on the library [`transformers-cfg`](https://github.com/epfl-dlab/transformers-CFG).
            For demo purposes, gpt2 is used, but you can use much larger models for better performance.

            The inference is a bit slow because it runs on **CPU (~20s for 30 tokens)**.
            The constraint itself **doesn't** introduce significant overhead to the inference.

            The output may be **truncated** to {MAX_NEW_TOKENS} tokens because generation is capped by the `max_new_tokens` parameter.
            In practice, with a sufficiently large `max_new_tokens`, your JSON output will be **complete** and **valid**.
            """
        )

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", lines=3, value="This is a valid json string describing a Pokémon character:")
                button = gr.Button(f"Generate with json object using {MODEL_NAME}!")
            with gr.Column():
                # color_map goes to the constructor: `.style()` is deprecated in
                # Gradio 3.x and removed in Gradio 4.
                highlighted_text = gr.HighlightedText(
                    label="Highlighted generation",
                    combine_adjacent=True,
                    show_legend=True,
                    color_map=LABEL_TO_COLOR,
                )

        button.click(get_tokens_and_labels, inputs=prompt, outputs=highlighted_text)

    demo.launch()