Spaces:
Runtime error
Runtime error
File size: 7,037 Bytes
658b022 ef0bdc3 61e66c3 658b022 ef0bdc3 61e66c3 ef0bdc3 61e66c3 ef0bdc3 61e66c3 ef0bdc3 61e66c3 ef0bdc3 61e66c3 ef0bdc3 61e66c3 ef0bdc3 658b022 61e66c3 658b022 61e66c3 ef0bdc3 61e66c3 ef0bdc3 61e66c3 658b022 ef0bdc3 61e66c3 ef0bdc3 658b022 61e66c3 658b022 ef0bdc3 658b022 ef0bdc3 61e66c3 ef0bdc3 658b022 ef0bdc3 61e66c3 ef0bdc3 61e66c3 658b022 ef0bdc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import gradio as gr
from transformers import AutoTokenizer
from transformers import pipeline
from utils import format_moves
import pandas as pd
# Checkpoint name used only to load the tokenizer; the generation weights
# come from the fine-tuned model named in the pipeline call below.
model_checkpoint = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Text-generation pipeline shared by every generation callback in this file.
generate = pipeline("text-generation",
                    model="arjunpatel/distilgpt2-finetuned-pokemon-moves",
                    tokenizer=tokenizer)
# Prompt prefix that each callback prepends to the user-supplied move name.
seed_text = "This move is called "
import tensorflow as tf
# Fix TensorFlow's global random seed.
# NOTE(review): presumably meant to make sampled generations repeatable;
# that only holds if the pipeline runs on the TensorFlow backend - confirm.
tf.random.set_seed(0)
# TODO: add a function to sanitize inputs
# - remove extra spaces
# - make sure each word is capitalized
# - format the moves so that it's clearer when each move is listed
# - play with the max length parameter a bit, and try to remove sentences that don't end in periods.
def update_history(df, move_name, move_desc, generation, parameters):
    """Return a new history table with one generated move appended.

    The input frame is not modified; a new DataFrame is returned.

    Args:
        df: Existing history as a pandas DataFrame, or None / an empty frame
            when nothing has been generated yet.
        move_name: Move name the user typed in.
        move_desc: Generated description for that move.
        generation: Label of the decoding strategy used (e.g. "beam").
        parameters: Decoding parameters to record; stored as given.

    Returns:
        A DataFrame holding the previous rows followed by the new row,
        indexed 0..n-1.
    """
    # TODO: format each move description with new lines to cut down on width.
    new_row = pd.DataFrame([{"Move Name": move_name,
                             "Move Description": move_desc,
                             "Generation Type": generation,
                             "Parameters": parameters}])
    # Nothing to append to: the new row alone is the whole history. This also
    # avoids concatenating with an empty frame.
    if df is None or df.empty:
        return new_row
    # ignore_index renumbers the rows; without it every appended row would
    # carry the duplicate index label 0.
    return pd.concat([df, new_row], ignore_index=True)
def create_move(move, history):
    """Generate a move description with the pipeline's default decoding.

    Args:
        move: Move name typed by the user; appended to the seed prompt.
        history: Current history table, forwarded to update_history.

    Returns:
        A tuple of (formatted description, updated history table).
    """
    prompt = seed_text + move
    raw_output = generate(prompt, num_return_sequences=1)
    # format_moves (from utils) turns the raw pipeline output into display text.
    description = format_moves(raw_output)
    updated_history = update_history(history, move, description,
                                     "baseline", "None")
    return description, updated_history
def create_greedy_search_move(move, history):
    """Generate a move description with sampling turned off.

    Args:
        move: Move name typed by the user; appended to the seed prompt.
        history: Current history table, forwarded to update_history.

    Returns:
        A tuple of (formatted description, updated history table).
    """
    prompt = seed_text + move
    raw_output = generate(prompt, do_sample=False)
    # format_moves (from utils) turns the raw pipeline output into display text.
    description = format_moves(raw_output)
    updated_history = update_history(history, move, description,
                                     "greedy", "None")
    return description, updated_history
def create_beam_search_move(move, num_beams, history):
    """Generate a move description using beam search.

    Args:
        move: Move name typed by the user; appended to the seed prompt.
        num_beams: Number of beams chosen in the UI. Coerced to int because
            a slider may deliver the value as a float.
        history: Current history table, forwarded to update_history.

    Returns:
        A tuple of (formatted description, updated history table).
    """
    num_beams = int(num_beams)
    generated_move = format_moves(generate(seed_text + move,
                                           num_beams=num_beams,
                                           num_return_sequences=1,
                                           do_sample=False,
                                           early_stopping=True))
    # Record the beam count that was actually used, not a fixed placeholder.
    return generated_move, update_history(history, move, generated_move,
                                          "beam", {"num_beams": num_beams})
def create_sampling_search_move(move, do_sample, temperature, history):
    """Generate a move description with optional temperature sampling.

    Args:
        move: Move name typed by the user; appended to the seed prompt.
        do_sample: Whether to sample; passed straight through to generation.
        temperature: Sampling temperature; converted to float before use.
        history: Current history table, forwarded to update_history.

    Returns:
        A tuple of (formatted description, updated history table).
    """
    # top_k=0 switches off top-k filtering so temperature alone shapes the
    # sampling distribution. The keyword must be spelled `top_k`; `topk` is
    # not a generation argument.
    generated_move = format_moves(generate(seed_text + move,
                                           do_sample=do_sample,
                                           temperature=float(temperature),
                                           num_return_sequences=1,
                                           top_k=0))
    return generated_move, update_history(history, move, generated_move,
                                          "temperature", {"do_sample": do_sample,
                                                          "temperature": temperature})
def create_top_search_move(move, topk, topp, history):
    """Generate a move description with top-k / top-p (nucleus) sampling.

    Args:
        move: Move name typed by the user; appended to the seed prompt.
        topk: Top-k value chosen in the UI; coerced to int for generation.
        topp: Top-p value chosen in the UI; coerced to float for generation.
        history: Current history table, forwarded to update_history.

    Returns:
        A tuple of (formatted description, updated history table).
    """
    # The previous `force_word_ids=...` argument was dropped: it is not a
    # valid generation keyword, and the real `force_words_ids` option needs
    # beam search with sampling disabled, so it cannot be combined with
    # do_sample=True.
    generated_move = format_moves(generate(
        seed_text + move,
        do_sample=True,
        num_return_sequences=1,
        top_k=int(topk),
        top_p=float(topp)))
    return generated_move, update_history(history, move, generated_move,
                                          "top", {"top k": topk,
                                                  "top p": topp})
# Gradio UI: one tab per decoding strategy plus a shared history table.
demo = gr.Blocks()

with demo:
    gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
    gr.Markdown(
        "This Gradio demo is a small GPT-2 model fine-tuned on a dataset of Pokemon moves! It'll generate a move description given a name.")
    gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below!")
    with gr.Tabs():
        with gr.TabItem("Standard Generation"):
            with gr.Row():
                text_input_baseline = gr.Textbox(label="Move",
                                                 placeholder="Type a two or three word move name here! Try \"Wonder Shield\"!")
                text_output_baseline = gr.Textbox(label="Move Description",
                                                  placeholder="Leave this blank!")
            text_button_baseline = gr.Button("Create my move!")
        with gr.TabItem("Greedy Search"):
            gr.Markdown("This tab lets you learn about using greedy search!")
            with gr.Row():
                text_input_greedy = gr.Textbox(label="Move")
                text_output_greedy = gr.Textbox(label="Move Description")
            text_button_greedy = gr.Button("Create my move!")
        with gr.TabItem("Beam Search"):
            gr.Markdown("This tab lets you learn about using beam search!")
            with gr.Row():
                num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1,
                                      label="Number of Beams")
                text_input_beam = gr.Textbox(label="Move")
                text_output_beam = gr.Textbox(label="Move Description")
            text_button_beam = gr.Button("Create my move!")
        with gr.TabItem("Sampling and Temperature Search"):
            gr.Markdown("This tab lets you experiment with adjusting the temperature of the generator")
            with gr.Row():
                temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
                                        label="Temperature")
                sample_boolean = gr.Checkbox(label="Enable Sampling?")
                text_input_temp = gr.Textbox(label="Move")
                text_output_temp = gr.Textbox(label="Move Description")
            text_button_temp = gr.Button("Create my move!")
        with gr.TabItem("Top K and Top P Sampling"):
            gr.Markdown("This tab lets you learn about Top K and Top P Sampling")
            with gr.Row():
                # The slider ranges now include the default values: the
                # defaults 0 (top-k) and 1 (top-p) previously fell outside
                # the declared minimum of 10 and maximum of 0.95.
                topk = gr.Slider(minimum=0, maximum=100, value=0, step=5,
                                 label="Top K")
                topp = gr.Slider(minimum=0.10, maximum=1.0, value=1, step=0.05,
                                 label="Top P")
                text_input_top = gr.Textbox(label="Move")
                text_output_top = gr.Textbox(label="Move Description")
            text_button_top = gr.Button("Create my move!")
    with gr.Box():
        # Displays a dataframe with the history of moves generated, with parameters
        history = gr.Dataframe(headers=["Move Name", "Move Description", "Generation Type", "Parameters"])

    # Each button passes its tab's inputs plus the shared history table to
    # its callback, and writes back the description and updated history.
    text_button_baseline.click(create_move, inputs=[text_input_baseline, history],
                               outputs=[text_output_baseline, history])
    text_button_greedy.click(create_greedy_search_move, inputs=[text_input_greedy, history],
                             outputs=[text_output_greedy, history])
    text_button_temp.click(create_sampling_search_move,
                           inputs=[text_input_temp, sample_boolean, temperature, history],
                           outputs=[text_output_temp, history])
    text_button_beam.click(create_beam_search_move, inputs=[text_input_beam, num_beams, history],
                           outputs=[text_output_beam, history])
    text_button_top.click(create_top_search_move, inputs=[text_input_top, topk, topp, history],
                          outputs=[text_output_top, history])

demo.launch(share=True)
|