Jordan Legg committed on
Commit
1c4aefd
·
1 Parent(s): 3be64a5

removed img2img :(

Files changed (1)
  1. app.py +26 -110
app.py CHANGED
@@ -1,145 +1,61 @@
  import spaces
  import gradio as gr
- import numpy as np
- import random
  import torch
- import torch.nn as nn
- from PIL import Image
- from torchvision import transforms
  from diffusers import DiffusionPipeline

  # Constants
- dtype = torch.bfloat16
- device = "cuda" if torch.cuda.is_available() else "cpu"
- MAX_SEED = np.iinfo(np.int32).max
+ MAX_SEED = 2**32 - 1
  MAX_IMAGE_SIZE = 2048
- LATENT_CHANNELS = 16
- TEXT_EMBED_DIM = 768
- MAX_TEXT_EMBEDDINGS = 77
- SCALING_FACTOR = 0.3611

  # Load FLUX model
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+ pipe = pipe.to("cuda")
  pipe.enable_model_cpu_offload()
  pipe.vae.enable_slicing()
  pipe.vae.enable_tiling()

- # Add a projection layer to match text embedding dimension
- projection = nn.Linear(LATENT_CHANNELS, TEXT_EMBED_DIM).to(device).to(dtype)
-
- def preprocess_image(image, image_size):
-     preprocess = transforms.Compose([
-         transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
-         transforms.ToTensor(),
-         transforms.Normalize([0.5], [0.5])
-     ])
-     image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
-     return image
-
- def process_latents(latents, height, width):
-     print(f"Input latent shape: {latents.shape}")
-
-     # Ensure latents are the correct shape
-     if latents.shape[2:] != (height // 8, width // 8):
-         latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
-         print(f"Latent shape after potential interpolation: {latents.shape}")
-
-     # Reshape latents to [batch_size, seq_len, channels]
-     latents = latents.permute(0, 2, 3, 1).reshape(1, -1, LATENT_CHANNELS)
-     print(f"Reshaped latent shape: {latents.shape}")
-
-     # Project latents to match text embedding dimension
-     latents = projection(latents)
-     print(f"Projected latent shape: {latents.shape}")
-
-     # Adjust sequence length to match text embeddings
-     seq_len = latents.shape[1]
-     if seq_len > MAX_TEXT_EMBEDDINGS:
-         latents = latents[:, :MAX_TEXT_EMBEDDINGS, :]
-     elif seq_len < MAX_TEXT_EMBEDDINGS:
-         pad_len = MAX_TEXT_EMBEDDINGS - seq_len
-         latents = torch.nn.functional.pad(latents, (0, 0, 0, pad_len, 0, 0))
-     print(f"Final latent shape: {latents.shape}")
-
-     return latents
-
  @spaces.GPU()
- def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     generator = torch.Generator(device=device).manual_seed(seed)
-
+ def generate_image(prompt, seed, width, height, num_inference_steps):
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+
      try:
-         if init_image is None:
-             # text2img case
-             image = pipe(
-                 prompt=prompt,
-                 height=height,
-                 width=width,
-                 num_inference_steps=num_inference_steps,
-                 generator=generator,
-                 guidance_scale=0.0
-             ).images[0]
-         else:
-             # img2img case
-             init_image = init_image.convert("RGB")
-             init_image = preprocess_image(init_image, 1024) # Using 1024 as FLUX VAE sample size
-
-             # Encode the image using FLUX VAE
-             latents = pipe.vae.encode(init_image).latent_dist.sample() * SCALING_FACTOR
-             print(f"Initial latent shape from VAE: {latents.shape}")
-
-             # Process latents to match text embedding format
-             latents = process_latents(latents, height, width)
-
-             # Get text embeddings
-             text_embeddings = pipe.transformer.text_encoder([prompt])
-             print(f"Text embedding shape: {text_embeddings.shape}")
-
-             # Combine image latents and text embeddings
-             combined_embeddings = torch.cat([latents, text_embeddings], dim=1)
-             print(f"Combined embedding shape: {combined_embeddings.shape}")
-
-             image = pipe(
-                 prompt=prompt,
-                 height=height,
-                 width=width,
-                 num_inference_steps=num_inference_steps,
-                 generator=generator,
-                 guidance_scale=0.0,
-                 latents=combined_embeddings
-             ).images[0]
-
+         image = pipe(
+             prompt=prompt,
+             height=height,
+             width=width,
+             num_inference_steps=num_inference_steps,
+             generator=generator,
+             guidance_scale=0.0
+         ).images[0]
+
          return image, seed
      except Exception as e:
-         print(f"Error during inference: {e}")
+         print(f"Error during image generation: {e}")
          import traceback
          traceback.print_exc()
-         return Image.new("RGB", (width, height), (255, 0, 0)), seed # Red fallback image
+         return None, seed

- # Gradio interface setup
+ # Gradio interface
  with gr.Blocks() as demo:
      with gr.Row():
          prompt = gr.Textbox(label="Prompt")
-         init_image = gr.Image(label="Initial Image (optional)", type="pil")
-
+
      with gr.Row():
          generate = gr.Button("Generate")
-
+
      with gr.Row():
-         result = gr.Image(label="Result")
-         seed_output = gr.Number(label="Seed")
-
+         result = gr.Image(label="Generated Image")
+         seed_output = gr.Number(label="Seed Used")
+
      with gr.Accordion("Advanced Settings", open=False):
-         seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
-         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+         seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, randomize=True)
          width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
          height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
          num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

      generate.click(
-         infer,
-         inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
+         generate_image,
+         inputs=[prompt, seed, width, height, num_inference_steps],
          outputs=[result, seed_output]
      )
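
For reference, the deleted img2img path tried to condition FLUX by projecting VAE latents into the text-embedding dimension and concatenating them with the prompt embeddings, which is not how the pipeline expects image conditioning. A minimal sketch of how img2img could be reintroduced, assuming a diffusers release that ships FluxImg2ImgPipeline; the input path, prompt, and strength value below are illustrative only:

import torch
from diffusers import FluxImg2ImgPipeline
from diffusers.utils import load_image

# Sketch only: use the dedicated FLUX img2img pipeline rather than
# hand-rolled latent projection into the text-embedding space.
pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

init_image = load_image("input.png")  # illustrative input path
image = pipe(
    prompt="a watercolor painting of a fox",  # illustrative prompt
    image=init_image,
    strength=0.6,  # illustrative: how far to move away from the input image
    num_inference_steps=4,
    guidance_scale=0.0,
    generator=torch.Generator("cpu").manual_seed(42),
).images[0]
image.save("img2img_output.png")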