File size: 3,626 Bytes
33c29ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18b055d
33c29ab
 
 
 
 
 
 
 
7700f59
 
20e485e
70c0209
20e485e
33c29ab
2c48ac2
 
 
 
 
 
33c29ab
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from diffusers import StableDiffusionPipeline
import gradio as gr
import torch

# Hugging Face Hub model ids that can be selected in the UI.
models = [
    "DGSpitzer/Cyberpunk-Anime-Diffusion",
]

# Style text prepended to the user's prompt, keyed by model id.
prompt_prefixes = {
    models[0]: "cyberpunk anime style",
}

# Model id of the pipeline currently held in `pipe`.
current_model = models[0]

# Half precision is meant for CUDA inference. On CPU, PyTorch may lack
# float16 kernels (or run them very slowly), so fall back to float32 there.
pipe = StableDiffusionPipeline.from_pretrained(
    current_model,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

# Human-readable label shown in the UI.
device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"

def on_model_change(model):
    """Load ``model`` into the global pipeline if it is not already active.

    Args:
        model: Model id selected in the UI dropdown.

    Side effects:
        Rebinds the module globals ``pipe`` and ``current_model`` when
        ``model`` differs from ``current_model``; otherwise does nothing.
    """
    global current_model
    global pipe
    if model == current_model:
        return

    use_cuda = torch.cuda.is_available()
    # float16 is only used on CUDA; on CPU PyTorch may lack half-precision
    # kernels (or run them very slowly), so load in float32 there.
    new_pipe = StableDiffusionPipeline.from_pretrained(
        model,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    if use_cuda:
        new_pipe = new_pipe.to("cuda")

    # Publish the new state only after loading succeeded, so a failed load
    # leaves the previous (current_model, pipe) pair consistent.
    pipe = new_pipe
    current_model = model

def inference(prompt, guidance, steps):
    """Generate one 512x512 image for ``prompt`` with the active pipeline.

    The style prefix registered for ``current_model`` in ``prompt_prefixes``
    (if any) is prepended to the prompt, separated by ", " so the prefix and
    the user's text do not run together into a single word.

    Args:
        prompt: Text prompt entered by the user.
        guidance: Value passed to the pipeline as ``guidance_scale``.
        steps: Number of inference steps; converted to ``int`` because UI
            sliders may deliver a float.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    # .get avoids a KeyError for a model that has no registered prefix.
    prefix = prompt_prefixes.get(current_model, "").strip()
    if prefix:
        prompt = f"{prefix}, {prompt}" if prompt else prefix

    result = pipe(
        prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        width=512,
        height=512,
    )
    return result.images[0]

# Build the Gradio UI: header, input controls on the left, output image on the
# right, clickable example prompts, and a footer with credits.
with gr.Blocks() as demo:
    gr.HTML(
        """
            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
              <div
                style="
                  display: inline-flex;
                  align-items: center;
                  gap: 0.8rem;
                  font-size: 1.75rem;
                "
              >
                <h1 style="font-weight: 900; margin-bottom: 7px;">
                  DGS Diffusion Space
                </h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 94%">
               Demo for Cyberpunk Anime Diffusion. Based on Finetuned Diffusion by <a href="https://twitter.com/hahahahohohe">anzorq</a>.
              </p>
            </div>
        """
    )
    with gr.Row():
        with gr.Column():
            model = gr.Dropdown(label="Model", choices=models, value=models[0])
            prompt = gr.Textbox(
                label="Prompt",
                placeholder=f"{prompt_prefixes[current_model]} is added automatically",
            )
            guidance = gr.Slider(label="Guidance scale", value=7.5, maximum=15)
            steps = gr.Slider(label="Steps", value=27, maximum=100, minimum=2)
            run = gr.Button(value="Run")
            gr.Markdown(f"Running on: {device}")
        with gr.Column():
            image_out = gr.Image(height=512)

    # Reload the pipeline when another model is picked; nothing in the UI is
    # updated by this handler, hence the empty outputs list.
    model.change(on_model_change, inputs=model, outputs=[])
    run.click(inference, inputs=[prompt, guidance, steps], outputs=image_out)

    # Each example row is [prompt, guidance scale, steps]. Example outputs are
    # only pre-computed (cached) when a GPU is available.
    gr.Examples(
        examples=[
            ["a beautiful perfect face girl in dgs illustration style, Anime fine details portrait of school girl in front of modern tokyo city landscape on the background deep bokeh, anime masterpiece by studio ghibli, 8k, sharp high quality anime, artstation", 7.5, 27],
            ["dgs landscape with fancy car", 7.5, 27],
            ["portrait of liu yifei girl in dgs illustration style,  soldier working in a cyberpunk city, cleavage, intricate, 8k, highly detailed, digital painting, intense, sharp focus", 7.5, 27],
            ["portrait of  in dgs illustration style,  soldier working in a  cyberpunk city, cleavage, intricate, 8k, highly detailed, digital painting, intense, sharp focus", 7.5, 27],
        ],
        inputs=[prompt, guidance, steps],
        outputs=image_out,
        fn=inference,
        cache_examples=torch.cuda.is_available(),
    )
    gr.Markdown('''
      Models and Space by [@DGSpitzer](https://huggingface.co/DGSpitzer)❤️<br>
      [![Twitter Follow](https://img.shields.io/twitter/follow/DGSpitzer?label=%40DGSpitzer&style=social)](https://twitter.com/DGSpitzer)

      ![visitors](https://visitor-badge.glitch.me/badge?page_id=dgspitzer_DGS_Diffusion_Space)
    ''')

# Enable request queuing, then start the app.
demo.queue()
demo.launch()