Spaces:
Running
Running
File size: 3,233 Bytes
f97bf68 69855af f97bf68 b0413b4 f97bf68 244f082 ead2968 244f082 f97bf68 244f082 f97bf68 244f082 ead2968 69855af a81c6ef ead2968 244f082 f97bf68 244f082 8caf209 f97bf68 69855af 244f082 ead2968 a81c6ef 6753cc6 244f082 f97bf68 244f082 f97bf68 244f082 f97bf68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import gradio as gr
import torch
import random
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Choose the device before loading the weights, so the dtype can match it.
# float16 halves GPU memory, but half-precision kernels on CPU are slow (and
# missing entirely in older PyTorch releases), so CPU inference uses float32.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
    print("Using GPU")
else:
    device = "cpu"
    dtype = torch.float32
    print("Using CPU")

# The tokenizer comes from the base FLAN-T5-small checkpoint.
# NOTE(review): presumably superprompt-v1 shares that vocabulary; confirm.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=dtype)
model.to(device)
def generate(
    prompt,
    history,
    max_new_tokens,
    repetition_penalty,
    temperature,
    top_p,
    top_k,
    random_seed,
    seed,
):
    """Expand ``prompt`` into a more detailed prompt with the SuperPrompt model.

    Used as the ``fn`` of ``gr.ChatInterface``. It receives the user's message,
    the chat history, and then the values of ``additional_inputs`` in order.

    Args:
        prompt: The user's message, passed to the model unchanged.
        history: Previous chat turns supplied by Gradio. Not used: each request
            is expanded on its own. The previous version appended the Python
            repr of this list to the model input (e.g. ``"..., []"``).
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Must be > 0. Values above 1 discourage repetition.
        temperature: Sampling temperature. ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold, clamped into [0, 1].
        top_k: Number of highest-probability tokens kept when sampling.
        random_seed: If truthy, draw a fresh seed instead of using ``seed``.
        seed: Manual seed, used when ``random_seed`` is falsy. If it is
            ``None`` (empty field), a random seed is drawn instead.

    Returns:
        The decoded model output, without special tokens such as ``<pad>``
        and ``</s>``.

    Raises:
        gr.Error: If ``repetition_penalty`` is not strictly positive.
    """
    if random_seed or seed is None:
        seed = random.randint(1, 100000)
    # gr.Number can deliver a float, so cast the seed to int.
    torch.manual_seed(int(seed))

    # transformers requires a strictly positive float repetition penalty.
    repetition_penalty = float(repetition_penalty)
    if repetition_penalty <= 0:
        raise gr.Error("Repetition Penalty must be greater than 0.")

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    generation_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": repetition_penalty,
    }
    temperature = float(temperature)
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            # top_p is a probability mass; values outside [0, 1] are rejected.
            top_p=min(max(float(top_p), 0.0), 1.0),
            top_k=int(top_k),
        )
    else:
        # Sampling with temperature 0 is rejected by transformers; use greedy decoding.
        generation_kwargs["do_sample"] = False

    outputs = model.generate(input_ids, **generation_kwargs)
    # skip_special_tokens keeps "<pad>"/"</s>" out of the chat reply.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extra controls shown under the chat box. Their values are passed to
# generate() after (message, history), in this list's order:
# max_new_tokens, repetition_penalty, temperature, top_p, top_k,
# random_seed, seed.
additional_inputs = [
    gr.Slider(
        value=512,
        minimum=250,
        maximum=512,
        step=1,
        interactive=True,
        label="Max New Tokens",
        info="The maximum numbers of new tokens, controls how long is the output",
    ),
    gr.Slider(
        value=1.2,
        # transformers rejects a repetition penalty of 0; it must be strictly positive.
        minimum=0.05,
        maximum=2,
        step=0.05,
        interactive=True,
        label="Repetition Penalty",
        info="Penalize repeated tokens, making the AI repeat less itself",
    ),
    gr.Slider(
        value=0.5,
        # Sampling (do_sample=True in generate) requires a strictly positive temperature.
        minimum=0.05,
        maximum=1,
        step=0.05,
        interactive=True,
        label="Temperature",
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        value=1,
        minimum=0,
        # top_p is a cumulative probability; transformers rejects values above 1.
        maximum=1,
        step=0.05,
        interactive=True,
        label="Top P",
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        value=1,
        minimum=1,
        maximum=100,
        step=1,
        interactive=True,
        label="Top K",
        info="Higher k means more diverse outputs by considering a range of tokens",
    ),
    gr.Checkbox(
        value=False,
        label="Use Random Seed",
        info="Check to use a random seed which is a start point for the generation process",
    ),
    gr.Number(
        value=42,
        # precision=0 makes Gradio deliver the seed as an int, not a float.
        precision=0,
        interactive=True,
        label="Manual Seed",
        info="A starting point to initiate the generation process",
    ),
]
# A single pre-filled example row for the chat UI. After the prompt, the values
# follow the order of additional_inputs:
# max new tokens, repetition penalty, temperature, top p, top k,
# use-random-seed flag, manual seed.
_EXAMPLE_PROMPT = "Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it."
examples = [[_EXAMPLE_PROMPT, 512, 1.2, 0.5, 1, 50, False, 42]]
# Chat transcript widget: no label, no share button, a copy button on messages,
# like buttons, and the "panel" layout.
# NOTE(review): likeable=True shows like buttons, but no .like() handler is
# registered in this code.
chat_display = gr.Chatbot(
    show_label=False,
    show_share_button=False,
    show_copy_button=True,
    likeable=True,
    layout="panel",
)

# generate() handles every message. It receives the additional_inputs values
# as extra arguments, and at most 20 requests run concurrently.
demo = gr.ChatInterface(
    fn=generate,
    chatbot=chat_display,
    additional_inputs=additional_inputs,
    title="SuperPrompt-v1",
    description="Make your prompts more detailed!",
    examples=examples,
    concurrency_limit=20,
)
demo.launch(show_api=False)