alexkueck commited on
Commit
5d0a205
·
1 Parent(s): 07a9c58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -19
app.py CHANGED
@@ -1,36 +1,51 @@
1
  import gradio as gr
2
  from diffusers import DiffusionPipeline
3
- import random
 
 
 
 
4
 
5
 
6
- #########################
7
  #Alternativ erzeugen
8
  #gr.Interface.load("models/stabilityai/stable-diffusion-2").launch()
9
 
10
- pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
11
 
12
- def erzeuge():
 
 
13
  pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
14
- images = [
15
- (random.choice(
16
- [
17
- "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
18
- "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
19
- "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
20
- "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
21
- "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
22
- ]
23
- ), f"label {i}" if i != 0 else "label" * 50)
24
- for i in range(3)
25
- ]
26
- return images
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  with gr.Blocks() as demo:
31
  with gr.Column(variant="panel"):
32
  with gr.Row(variant="compact"):
33
- text = gr.Textbox(
34
  label="Deine Beschreibung:",
35
  show_label=False,
36
  max_lines=1,
@@ -44,7 +59,7 @@ with gr.Blocks() as demo:
44
  label="Erzeugte Bilder", show_label=False, elem_id="gallery"
45
  ).style(columns=[2], rows=[2], object_fit="contain", height="auto")
46
 
47
- btn.click(erzeuge, None, gallery)
48
 
49
  if __name__ == "__main__":
50
  demo.launch()
 
1
  import gradio as gr
2
  from diffusers import DiffusionPipeline
3
+ import torch
4
+ #für die komplexere Variante der Erzeugung
5
+ #from diffusers import DDPMScheduler, UNet2DModel
6
+ #from PIL import Image
7
+ #import numpy as np
8
 
9
 
10
+ #######################################
11
  #Alternativ erzeugen
12
  #gr.Interface.load("models/stabilityai/stable-diffusion-2").launch()
13
 
 
14
 
15
+ #######################################
16
+ #Bild nach dem eingegebenen prompt erzeugen - mit Pipeline
17
+ def erzeuge(prompt):
18
  pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
19
+ pipeline.to("cuda")
20
+ return pipeline(prompt).images[0]
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ ########################################
23
+ #Bild erzeugen - nich über Pipeline sondern mit mehr Einstellungsmöglichkeiten
24
+ def erzeuge_komplex(prompt):
25
+ scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
26
+ model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
27
+ scheduler.set_timesteps(50)
28
+
29
+ sample_size = model.config.sample_size
30
+ noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
31
+ input = noise
32
+
33
+ for t in scheduler.timesteps:
34
+ with torch.no_grad():
35
+ noisy_residual = model(input, t).sample
36
+ prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
37
+ input = prev_noisy_sample
38
+
39
+ image = (input / 2 + 0.5).clamp(0, 1)
40
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
41
+ image = Image.fromarray((image * 255).round().astype("uint8"))
42
+ return image
43
 
44
 
45
  with gr.Blocks() as demo:
46
  with gr.Column(variant="panel"):
47
  with gr.Row(variant="compact"):
48
+ user_input = gr.Textbox(
49
  label="Deine Beschreibung:",
50
  show_label=False,
51
  max_lines=1,
 
59
  label="Erzeugte Bilder", show_label=False, elem_id="gallery"
60
  ).style(columns=[2], rows=[2], object_fit="contain", height="auto")
61
 
62
+ btn.click(erzeuge, inputs=[user_input], gallery)
63
 
64
  if __name__ == "__main__":
65
  demo.launch()