# SuperPrompt-v1 / app.py
# Author: Nick088 — last commit b0413b4 ("Update app.py", verified), 3.23 kB
import gradio as gr
import torch
import random
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Pick the execution device first so the model weights can be loaded in a
# dtype that device handles well. Half precision is used only on CUDA: on CPU,
# float16 matmuls are slow and, on many PyTorch releases, not implemented at
# all (errors like "addmm_impl_cpu_" not implemented for 'Half').
if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")

# The SuperPrompt checkpoint is paired with the stock flan-t5-small tokenizer.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained(
    "roborovski/superprompt-v1",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model.to(device)
def generate(
    prompt,
    history,
    max_new_tokens,
    repetition_penalty,
    temperature,
    top_p,
    top_k,
    random_seed,
    seed,
):
    """Expand ``prompt`` into a more detailed prompt with the SuperPrompt model.

    This is the ``fn`` callback of the ``gr.ChatInterface`` at the bottom of the
    file. The parameters after ``history`` arrive positionally, in the same
    order as the controls in ``additional_inputs``.

    Args:
        prompt: The user's chat message, encoded and sent to the model.
        history: Chat history supplied by ChatInterface. Accepted for interface
            compatibility but intentionally not sent to the model (see below).
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Penalty applied to already-generated tokens.
        temperature: Sampling temperature (must be > 0 since sampling is on).
        top_p: Nucleus-sampling cumulative-probability cutoff, in [0, 1].
        top_k: Number of highest-probability tokens kept at each step.
        random_seed: If truthy, a fresh seed in 1..100000 replaces ``seed``.
        seed: Seed for torch's RNG when ``random_seed`` is falsy.

    Returns:
        The decoded model output with special tokens (padding /
        end-of-sequence markers) stripped.
    """
    # The model is a single-turn seq2seq model. Interpolating ``history`` would
    # splice its Python repr (typically ", []" on the first turn) into the
    # model input, so only the prompt itself is encoded.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    if random_seed:
        seed = random.randint(1, 100000)
    # gr.Number without precision returns a float; seed torch with an int.
    torch.manual_seed(int(seed))

    # Gradio controls may deliver ints where transformers' logits processors
    # require floats (and vice versa for top_k), so normalize the types here.
    outputs = model.generate(
        input_ids,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
    )
    # For T5 the generated sequence starts with the "<pad>" decoder-start token
    # and ends with "</s>"; drop them so the user sees only the prompt text.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extra controls rendered under the chat box. gr.ChatInterface passes their
# values to ``generate`` positionally after (message, history), so the order
# of this list must match generate's parameter order.
additional_inputs = [
    gr.Slider(
        value=512,
        minimum=250,
        maximum=512,
        step=1,
        interactive=True,
        label="Max New Tokens",
        info="The maximum numbers of new tokens, controls how long is the output",
    ),
    # transformers rejects a repetition_penalty that is not strictly positive,
    # so the range starts one step above 0.
    gr.Slider(
        value=1.2,
        minimum=0.05,
        maximum=2,
        step=0.05,
        interactive=True,
        label="Repetition Penalty",
        info="Penalize repeated tokens, making the AI repeat less itself",
    ),
    # Sampling (do_sample=True in generate) requires temperature > 0.
    gr.Slider(
        value=0.5,
        minimum=0.05,
        maximum=1,
        step=0.05,
        interactive=True,
        label="Temperature",
        info="Higher values produce more diverse outputs",
    ),
    # top_p is a cumulative probability; transformers rejects values above 1.
    gr.Slider(
        value=1,
        minimum=0,
        maximum=1,
        step=0.05,
        interactive=True,
        label="Top P",
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        value=1,
        minimum=1,
        maximum=100,
        step=1,
        interactive=True,
        label="Top K",
        info="Higher k means more diverse outputs by considering a range of tokens",
    ),
    gr.Checkbox(
        value=False,
        label="Use Random Seed",
        info="Check to use a random seed which is a start point for the generation process",
    ),
    # precision=0 makes Gradio round the entry and return it as an int.
    gr.Number(
        value=42,
        precision=0,
        interactive=True,
        label="Manual Seed",
        info="A starting point to initiate the generation process",
    ),
]
# A single clickable example for the chat UI. A row is the chat message
# followed by one value per control in ``additional_inputs``, in the same order.
_EXAMPLE_PROMPT = (
    "Expand the following prompt to add more detail: "
    "A storefront with 'Text to Image' written on it."
)
_EXAMPLE_SETTINGS = {
    "max_new_tokens": 512,
    "repetition_penalty": 1.2,
    "temperature": 0.5,
    "top_p": 1,
    "top_k": 50,
    "random_seed": False,
    "seed": 42,
}
examples = [[_EXAMPLE_PROMPT, *_EXAMPLE_SETTINGS.values()]]
# Chat transcript widget: no label or share button, but with per-message copy
# and like buttons, rendered in the "panel" layout.
_chatbot = gr.Chatbot(
    show_label=False,
    show_share_button=False,
    show_copy_button=True,
    likeable=True,
    layout="panel",
)

# Wire ``generate`` into a chat UI with the extra controls and example row,
# allow up to 20 concurrent submissions, and start the server with the API
# page hidden.
demo = gr.ChatInterface(
    fn=generate,
    chatbot=_chatbot,
    additional_inputs=additional_inputs,
    title="SuperPrompt-v1",
    description="Make your prompts more detailed!",
    examples=examples,
    concurrency_limit=20,
)
demo.launch(show_api=False)