# Spaces status: Runtime error
import os

import gradio as gr
import torch
from kandinsky2 import get_kandinsky2
from torch import autocast
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the Kandinsky 2.1 text-to-image model on the detected device. Passing
# the detected device type (a plain string, as a literal 'cuda' would be) keeps
# CPU-only hosts from failing at model construction.
model = get_kandinsky2(
    device.type,
    task_type='text2img',
    cache_dir='/tmp/kandinsky2',
    model_version='2.1',
    use_flash_attention=False,
)
def generate_text(prompt, quality="High (Default)"):
    """Generate one 768x768 image for ``prompt`` with the module-level model.

    Args:
        prompt: Text description of the image. When it is empty or only
            whitespace, the demo prompt ``'red cat, 4k photo'`` is used.
        quality: Label selecting the ``num_steps`` value passed to the model.
            Both label sets used in this file are accepted
            (``Low`` / ``High (Default)`` / ``Ultra`` and
            ``Low`` / ``Medium (Default)`` / ``High``). Unknown labels fall
            back to the default step count.

    Returns:
        The first element of the batch returned by
        ``model.generate_text2img`` (``batch_size`` is 1).
    """
    # Steps per quality tier: 50 / 100 / 150. The middle and top tiers each
    # appear under two labels so either naming scheme resolves.
    steps_by_quality = {
        "Low": 50,
        "Medium (Default)": 100,
        "High (Default)": 100,
        "High": 150,
        "Ultra": 150,
    }
    default_steps = steps_by_quality["High (Default)"]
    num_steps = steps_by_quality.get(quality, default_steps)

    # Use the caller's prompt; keep the demo prompt only as a fallback for
    # blank input.
    text = prompt.strip() if prompt else ""
    if not text:
        text = 'red cat, 4k photo'

    images = model.generate_text2img(
        text,
        num_steps=num_steps,
        batch_size=1,
        guidance_scale=4,
        h=768,
        w=768,
        sampler='p_sampler',
        prior_cf_scale=4,
        # NOTE(review): kept as a string, matching the upstream Kandinsky 2.1
        # usage example - confirm against the installed library version.
        prior_steps="5",
    )
    return images[0]
# Web UI: a free-text prompt and a quality selector go in; the value returned
# by generate_text is rendered as an image.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        "textbox",
        # The choices must be labels generate_text() can resolve, and the
        # pre-selected one mirrors that function's default quality.
        gr.inputs.Dropdown(
            ["Low", "High (Default)", "Ultra"],
            default="High (Default)",
            label="Quality",
        ),
    ],
    outputs=gr.outputs.Image(label="Generated image:"),
)
# Optional extra checkpoint. The path below is a placeholder, so it is loaded
# only when the file actually exists rather than failing at startup.
# NOTE(review): get_kandinsky2() presumably already loads pretrained weights
# via cache_dir - confirm whether this extra load is needed at all.
checkpoint_path = 'path/to/model.pth'
if os.path.isfile(checkpoint_path):
    # map_location=device handles both the CPU and the CUDA case, so no
    # per-device branch is required.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Start the Gradio server for the interface defined above.
iface.launch()