# Hugging Face Spaces status: Runtime error
import functools as ft | |
import gradio as gr | |
import torch | |
import transformers | |
from transformers import T5ForConditionalGeneration, T5Tokenizer | |
# Local checkpoint directory shared by the tokenizer and the model.
# local_files_only=True tells transformers not to download anything from the Hub.
# NOTE(review): with local_files_only=True, a missing directory presumably makes
# from_pretrained raise at import time; confirm the checkpoint ships with the app.
_MODEL_DIR = "models/roborovski/superprompt-v1"

tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(_MODEL_DIR, local_files_only=True)
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
    _MODEL_DIR,
    local_files_only=True,
)
def super_prompt(text: str, seed: int, max_new_tokens: int, prompt: str) -> str:
    """Expand a short text prompt into a more detailed one using the T5 model.

    Args:
        text: The prompt to expand.
        seed: Seed handed to ``transformers.set_seed`` so sampling is
            reproducible. ``None`` is treated as 0.
        max_new_tokens: Upper bound on the number of generated tokens.
            Values <= 0 (or ``None``) fall back to 150.
        prompt: Optional instruction prefix placed before ``text``. When
            empty, ``"Expand the following prompt to add more detail:"`` is used.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Gradio sliders can deliver floats (or None when cleared); set_seed and
    # generate() expect plain ints.
    seed = int(seed or 0)
    max_new_tokens = int(max_new_tokens or 0)
    if max_new_tokens <= 0:
        max_new_tokens = 150

    transformers.set_seed(seed)

    instruction = prompt if prompt else "Expand the following prompt to add more detail:"
    input_text = f"{instruction} {text}"

    with torch.inference_mode():
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        outputs = model.generate(
            input_ids,
            # Pass the limit under its own name: max_length has a different
            # meaning (total decoder length) from the number of new tokens.
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio UI. Inputs are passed to super_prompt positionally, so their order
# must match its parameters: (text, seed, max_new_tokens, prompt).
demo = gr.Interface(
    fn=super_prompt,
    inputs=[
        gr.Textbox(label="Text to expand"),
        gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed"),
        # The bare "slider" shortcut capped this at 100 with a default of 0,
        # so the 150-token fallback in super_prompt was unreachable via the UI.
        gr.Slider(
            minimum=0,
            maximum=512,
            step=1,
            value=150,
            label="Max new tokens (0 = default of 150)",
        ),
        gr.Textbox(
            label="Instruction prefix (optional)",
            placeholder="Expand the following prompt to add more detail:",
        ),
    ],
    outputs=[gr.Textbox(label="Expanded prompt")],
)
demo.launch()