import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

MODEL_NAME = "gpt2"

if __name__ == "__main__":
    # Define your model and your tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    # GPT-2 has no pad token, so fall back to the EOS token for padding
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = model.config.eos_token_id

    # Optional inference speed-up via BetterTransformer (requires the `optimum` package)
    model.to_bettertransformer()
    # Color-coding thresholds: if prob >= x, then label = y.
    # Sorted in descending probability order!
    probs_to_label = [
        (0.1, "p >= 10%"),
        (0.01, "p >= 1%"),
        (1e-20, "p < 1%"),
    ]
    label_to_color = {
        "p >= 10%": "green",
        "p >= 1%": "yellow",
        "p < 1%": "red",
    }
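    # Example: a token generated with proba = 0.05 fails the 0.1 threshold but
    # passes 0.01, so it gets the label "p >= 1%" and is rendered in yellow.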
    def get_tokens_and_labels(prompt):
        """
        Given the prompt (text), return a list of tuples (decoded_token, label)
        """
        inputs = tokenizer([prompt], return_tensors="pt")

        # Load the JSON grammar and create a fresh GrammarConstrainedLogitsProcessor
        # for each call, since the incremental constraint carries per-generation state
        with open("json_minimal.ebnf", "r") as file:
            grammar_str = file.read()
        grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
        grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
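        # For reference, a grammar in the EBNF dialect used by transformers-cfg
        # looks roughly like the sketch below (hypothetical; the actual
        # json_minimal.ebnf shipped with this demo may differ):
        #
        #   root   ::= object
        #   object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws
        #   value  ::= object | string | number | ("true" | "false" | "null") ws
        #   string ::= "\"" [a-zA-Z0-9 ]* "\"" ws
        #   number ::= [0-9]+ ws
        #   ws     ::= [ \t\n]*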
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            repetition_penalty=1.0,  # 1.0 = no repetition penalty
            return_dict_in_generate=True,
            output_scores=True,
            logits_processor=[grammar_processor],
        )
        # Important: set `normalize_logits=True` so the scores are normalized
        # log-probabilities (i.e. the probabilities at each step sum to 1)
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        transition_proba = np.exp(transition_scores)
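        # Example: a transition score of -2.3 (a log-probability) becomes
        # np.exp(-2.3) ≈ 0.1 after exponentiation, i.e. a 10% token probability.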
        # We only have scores for the generated tokens, so pop out the prompt tokens
        input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
        generated_tokens = outputs.sequences[:, input_length:]

        # Initialize the highlighted output with the prompt tokens, which get no color label
        highlighted_out = [(tokenizer.decode(token), None) for token in inputs.input_ids[0]]
        # Get the (decoded_token, label) pairs for the generated tokens
        for token, proba in zip(generated_tokens[0], transition_proba[0]):
            this_label = None
            assert 0.0 <= proba <= 1.0
            for min_proba, label in probs_to_label:
                if proba >= min_proba:
                    this_label = label
                    break
            highlighted_out.append((tokenizer.decode(token), this_label))

        return highlighted_out
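    # For illustration, the returned structure looks like (hypothetical values):
    #   [("This", None), (" is", None), ..., ("{", "p >= 10%"), ("\"", "p >= 1%"), ...]
    # i.e. exactly the (text, label) pairs that gr.HighlightedText expects.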
    demo = gr.Blocks()
    with demo:
        gr.Markdown(
            """
            # 👻 Transformers-CFG JSON Demo

            This is a demo of how you can constrain the output of a GPT-2 model to be a **valid** JSON string (**up to truncation**).
            Here we use a simple JSON grammar to constrain the output of the model.
            The grammar is defined in `json_minimal.ebnf` and is written in **Extended Backus-Naur Form (EBNF)**.
            Internally, the demo relies on the library [`transformers-cfg`](https://github.com/epfl-dlab/transformers-CFG).

            For demo purposes, GPT-2 is used, but you can use much larger models for better performance.
            Inference is a bit slow because it runs on **CPU** (~20s for 30 tokens).
            The constraint itself does **not** introduce significant overhead to the inference.

            The output may be **truncated** due to the `max_new_tokens` limit (set to 50 here).
            In practice, with a large enough `max_new_tokens`, your JSON output will be **complete** and **valid**.
            """
        )
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    value="This is a valid json string describing a Pokémon character:",
                )
                button = gr.Button(f"Generate a JSON object using {MODEL_NAME}!")
            with gr.Column():
                # Note: `.style()` targets Gradio 3.x; in Gradio 4+, pass
                # `color_map=label_to_color` to the HighlightedText constructor instead
                highlighted_text = gr.HighlightedText(
                    label="Highlighted generation",
                    combine_adjacent=True,
                    show_legend=True,
                ).style(color_map=label_to_color)

        button.click(get_tokens_and_labels, inputs=prompt, outputs=highlighted_text)
    # The whole script already runs under `if __name__ == "__main__":`, so launch directly
    demo.launch()
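    # To try this locally (assuming the script is saved as app.py with
    # json_minimal.ebnf in the same directory): `python app.py`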