import functools as ft

import gradio as gr
import torch
import transformers
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
    "models/roborovski/superprompt-v1", local_files_only=True
)
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
    "models/roborovski/superprompt-v1", local_files_only=True
)


# Cache results: set_seed makes sampling deterministic for a given seed, so
# identical requests can be answered from the cache instead of regenerating.
@ft.lru_cache(maxsize=1024)
def super_prompt(text: str, seed: int, max_new_tokens: int, prompt: str) -> str:
    # Gradio sliders deliver floats; set_seed and generate expect ints.
    seed = int(seed)
    max_new_tokens = int(max_new_tokens)
    transformers.set_seed(seed)
    if max_new_tokens <= 0:
        max_new_tokens = 150  # fall back to a sensible default budget
    with torch.inference_mode():
        if prompt:
            input_text = f"{prompt} {text}"
        else:
            input_text = f"Expand the following prompt to add more detail: {text}"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        outputs = model.generate(
            input_ids,
            # Use max_new_tokens, which counts only generated tokens; the
            # original max_length would also have counted the prompt tokens,
            # silently shrinking the generation budget.
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


demo = gr.Interface(
    fn=super_prompt,
    inputs=["text", "slider", "slider", "text"],
    outputs=["text"],
)
demo.launch()
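
# A minimal sketch of explicit input components, should the bare "slider"
# strings above prove too coarse. The ranges and labels here are assumptions
# for illustration, not values from the original app:
#
#   demo = gr.Interface(
#       fn=super_prompt,
#       inputs=[
#           gr.Textbox(label="Prompt text"),
#           gr.Slider(0, 2**32 - 1, step=1, label="Seed"),
#           gr.Slider(0, 512, step=1, label="Max new tokens"),
#           gr.Textbox(label="Instruction prefix (optional)"),
#       ],
#       outputs=[gr.Textbox(label="Expanded prompt")],
#   )
#
# Integer steps on the sliders would also make the int() casts inside
# super_prompt redundant.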