Nick088's picture
Updated description, fixed manual seed not being hidden when using random seed
e191ed9 verified
raw
history blame
3.25 kB
import gradio as gr
import torch
import random
from transformers import T5Tokenizer, T5ForConditionalGeneration
# The tokenizer is loaded from "google/flan-t5-small" while the model weights
# come from "roborovski/superprompt-v1".
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")

# Detect the device first so the weight dtype can depend on it.
if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")

# Use float16 only on GPU. On CPU, half precision is slow and, depending on the
# torch version, some kernels do not implement it, so load float32 there.
dtype = torch.float16 if device == "cuda" else torch.float32
model = T5ForConditionalGeneration.from_pretrained(
    "roborovski/superprompt-v1", torch_dtype=dtype
)
model.to(device)
def generate(
    system_prompt,
    prompt,
    max_new_tokens,
    repetition_penalty,
    temperature,
    top_p,
    top_k,
    random_seed,
    seed,
):
    """Expand *prompt* into a more detailed prompt with the loaded T5 model.

    Args:
        system_prompt: Instruction text placed before the user prompt.
        prompt: The user prompt to expand.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Passed through to ``model.generate``.
        temperature: Sampling temperature. A value <= 0 switches to greedy
            decoding, and temperature/top_p/top_k are then not passed.
        top_p: Nucleus-sampling cutoff, used only when sampling.
        top_k: Number of top tokens considered, used only when sampling.
        random_seed: If truthy, ignore *seed* and draw a random one.
        seed: Manual seed, used when *random_seed* is falsy. ``None`` also
            falls back to a random seed.

    Returns:
        The decoded output text with special tokens (padding, end-of-sequence)
        stripped.
    """
    # Join the non-empty parts with a single space so an instruction ending in
    # ":" reads "...detail: <prompt>" rather than "...detail:, <prompt>".
    parts = [part.strip() for part in (system_prompt, prompt) if part and part.strip()]
    input_text = " ".join(parts)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    # gr.Number may deliver a float, or None when the field is cleared;
    # torch.manual_seed needs an integer.
    if random_seed or seed is None:
        seed = random.randint(1, 100000)
    torch.manual_seed(int(seed))

    # Sliders may deliver ints or floats, and transformers type-checks several
    # of these arguments, so normalize them explicitly.
    temperature = float(temperature)
    do_sample = temperature > 0
    generation_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": float(repetition_penalty),
        "do_sample": do_sample,
    }
    if do_sample:
        generation_kwargs.update(
            temperature=temperature,
            top_p=float(top_p),
            top_k=int(top_k),
        )
    # With temperature 0, sampling is rejected by transformers; greedy decoding
    # is the limiting behavior, so do_sample=False is used instead.
    outputs = model.generate(input_ids, **generation_kwargs)

    # Drop <pad>/</s> so the UI shows only the generated prompt text.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# --- Gradio UI ---------------------------------------------------------------
# gr.Interface passes component values to `generate` positionally, in the order
# of the `inputs` list below, so that list (and each example row) must follow
# generate's signature: (system_prompt, prompt, max_new_tokens, ...).
prompt = gr.Textbox(label="Prompt", interactive=True)
system_prompt = gr.Textbox(label="System Prompt", interactive=True)
max_new_tokens = gr.Slider(
    value=512, minimum=250, maximum=512, step=1, interactive=True,
    label="Max New Tokens",
    info="The maximum numbers of new tokens, controls how long is the output",
)
# transformers rejects a non-positive repetition_penalty, so 0 is excluded.
repetition_penalty = gr.Slider(
    value=1.2, minimum=0.05, maximum=2, step=0.05, interactive=True,
    label="Repetition Penalty",
    info="Penalize repeated tokens, making the AI repeat less itself",
)
temperature = gr.Slider(
    value=0.5, minimum=0, maximum=1, step=0.05, interactive=True,
    label="Temperature", info="Higher values produce more diverse outputs",
)
# top_p is a cumulative probability mass, so values above 1 are meaningless.
top_p = gr.Slider(
    value=1, minimum=0, maximum=1, step=0.05, interactive=True,
    label="Top P", info="Higher values sample more low-probability tokens",
)
# top_k=1 keeps only the single most likely token, which makes sampling
# deterministic and the temperature/top_p/seed controls inert. Default to 50,
# matching the example row below.
top_k = gr.Slider(
    value=50, minimum=1, maximum=100, step=1, interactive=True,
    label="Top K",
    info="Higher k means more diverse outputs by considering a range of tokens",
)
use_random_seed = gr.Checkbox(
    value=False, label="Use Random Seed",
    info="Check to use a random seed which is a start point for the generation process",
)
# Visibility is fixed when the interface is built (no change listener is wired
# to the checkbox), so the field is always shown. generate() ignores it
# whenever "Use Random Seed" is checked.
manual_seed = gr.Number(
    value=42, interactive=True, label="Manual Seed",
    info="A starting point to initiate the generation process (ignored when Use Random Seed is checked)",
    visible=True,
)

# Each row follows the `inputs` order below.
examples = [
    [
        "Expand the following prompt to add more detail:",  # system prompt
        "A storefront with 'Text to Image' written on it.",  # prompt
        512,    # max_new_tokens
        1.2,    # repetition_penalty
        0.5,    # temperature
        1,      # top_p
        50,     # top_k
        False,  # use_random_seed
        42,     # manual_seed
    ]
]

demo = gr.Interface(
    fn=generate,
    inputs=[
        system_prompt,
        prompt,
        max_new_tokens,
        repetition_penalty,
        temperature,
        top_p,
        top_k,
        use_random_seed,
        manual_seed,
    ],
    outputs=gr.Textbox(label="Better Prompt", interactive=True),
    title="SuperPrompt-v1",
    description="Make your prompts more detailed!\nModel used: https://huggingface.co/roborovski/superprompt-v1\nHugging Face Space made by [Nick088](https://linktr.ee/Nick088)",
    examples=examples,
    live=True,
    concurrency_limit=20,
)
demo.launch(show_api=False)