"""Gradio demo exposing two text-to-image paths for Stable Diffusion.

* ``infer`` runs a locally loaded ``StableDiffusionPipeline``.
* ``get_images`` forwards the prompt to the hosted
  ``spaces/stabilityai/stable-diffusion`` Space.

NOTE(review): this file was recovered from a whitespace-collapsed source, so
the nesting of the UI components below is reconstructed from statement order
-- confirm the layout against the deployed app.
"""
import os
import re

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

auth_token = os.getenv("auth_token")
model_id = "CompVis/stable-diffusion-v1-4"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# The "fp16" weight revision is requested while the pipeline dtype is float32.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=auth_token,
    revision="fp16",
    torch_dtype=torch.float32,
)
pipe = pipe.to(device)

stable_diffusion = gr.Blocks.load(name="spaces/stabilityai/stable-diffusion")


def get_images(prompt):
    """Generate images for *prompt* through the hosted Stable Diffusion Space.

    Args:
        prompt: Text prompt forwarded to the remote Space (``fn_index=2``).

    Returns:
        Paths of every entry in the directory returned by the Space, sorted
        by name so the gallery order is stable between runs.
    """
    gallery_dir = stable_diffusion(prompt, fn_index=2)
    return [
        os.path.join(gallery_dir, img) for img in sorted(os.listdir(gallery_dir))
    ]


def infer(prompt, samples, steps, scale, seed):
    """Generate images for *prompt* with the locally loaded pipeline.

    Args:
        prompt: Text prompt; repeated ``samples`` times to form the batch.
        samples: Number of images to generate.
        steps: Number of inference steps.
        scale: Guidance scale passed to the pipeline.
        seed: Seed for the torch random generator.

    Returns:
        A list with every generated image. NSFW flags reported by the
        pipeline are not consulted; all images are returned as-is.
    """
    # Slider values may arrive as floats; list repetition and manual_seed
    # both require real integers.
    samples = int(samples)
    steps = int(steps)
    seed = int(seed)

    generator = torch.Generator(device=device).manual_seed(seed)
    result = pipe(
        [prompt] * samples,
        num_inference_steps=steps,
        guidance_scale=scale,
        generator=generator,
    )

    # Newer diffusers releases expose the images as ``.images``; the early
    # dict-style output stored them under the "sample" key.
    images = getattr(result, "images", None)
    if images is None:
        images = result["sample"]
    return list(images)


block = gr.Blocks()

with block:
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
                submit_button_text = gr.Button("Quick Convert to Image").style(
                    margin=True,
                    rounded=(True, True, True, True),
                )

        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")

        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")

        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
            )

        # Both pressing Enter in the prompt box and clicking the button run
        # the local pipeline and fill the main gallery.
        text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
        btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
        # Registered with no Python callback.
        advanced_button.click(
            None,
            [],
            text,
        )

    output = gr.Textbox(label="Output Box")
    sd_output = gr.Gallery().style(grid=2, height="auto")
    # The "Quick Convert" button uses the hosted Space instead of the local
    # pipeline and fills the second gallery.
    submit_button_text.click(fn=get_images, inputs=text, outputs=sd_output)

block.launch()