1inkusFace commited on
Commit
e95c348
·
verified ·
1 Parent(s): 89200e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py CHANGED
@@ -44,6 +44,40 @@ from diffusers import StableDiffusion3Pipeline, SD3Transformer2DModel, Autoencod
44
  from PIL import Image
45
  from image_gen_aux import UpscaleWithModel
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # --- GCS Configuration ---
48
  # Make sure to set these secrets in your Hugging Face Space settings
49
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
@@ -78,6 +112,17 @@ def upload_to_gcs(image_object, filename):
78
 
79
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
80
 
 
 
 
 
 
 
 
 
 
 
 
81
  def load_model():
82
  pipe = StableDiffusion3Pipeline.from_pretrained(
83
  "ford442/stable-diffusion-3.5-large-bf16",
@@ -89,11 +134,21 @@ def load_model():
89
  pipe.transformer=ll_transformer
90
  pipe.load_lora_weights("ford442/sdxl-vae-bf16", weight_name="LoRA/UltraReal.safetensors")
91
  pipe.to(device=device, dtype=torch.bfloat16)
 
 
 
 
92
  upscaler_2 = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(device)
 
93
  return pipe, upscaler_2
 
 
94
 
95
  pipe, upscaler_2 = load_model()
96
 
 
 
 
97
  MAX_SEED = np.iinfo(np.int32).max
98
  MAX_IMAGE_SIZE = 4096
99
 
 
44
  from PIL import Image
45
  from image_gen_aux import UpscaleWithModel
46
 
47
+
48
+
49
+ from diffusers.models.attention_processor import Attention
50
+ from kernels import get_kernel
51
+ vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")
52
+
53
+ class FlashAttentionProcessor(Attention):
54
+ def __init__(self):
55
+ super().__init__()
56
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
57
+ query = attn.to_q(hidden_states)
58
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
59
+ key = attn.to_k(encoder_hidden_states)
60
+ value = attn.to_v(encoder_hidden_states)
61
+ # Scale the queries
62
+ scale = attn.scale
63
+ query = query * scale
64
+ # Reshape to match kernel requirements
65
+ b, t, c = query.shape
66
+ h = attn.heads
67
+ q_reshaped = query.reshape(b, t, h, c // h)
68
+ k_reshaped = key.reshape(b, t, h, c // h)
69
+ v_reshaped = value.reshape(b, t, h, c // h)
70
+ out_reshaped = torch.empty_like(q_reshaped)
71
+ # Call the pre-compiled kernel
72
+ vllm_flash_attn3.attention(q_reshaped, k_reshaped, v_reshaped, out_reshaped)
73
+ # Reshape output back
74
+ out = out_reshaped.reshape(b, t, c)
75
+ out = attn.to_out[0](out)
76
+ out = attn.to_out[1](out)
77
+ return out
78
+
79
+
80
+
81
  # --- GCS Configuration ---
82
  # Make sure to set these secrets in your Hugging Face Space settings
83
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
 
112
 
113
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
114
 
115
+ @spaces.GPU(duration=120)
116
+ def compile_transformer():
117
+ with spaces.aoti_capture(pipe.transformer) as call:
118
+ pipe("A majestic, ancient Egyptian Sphinx stands sentinel in a large, clear pool under a bright, golden desert sun. Around its weathered stone base, several sleek, playful dolphins gracefully navigate the turquoise waters. The surrounding environment features lush, exotic papyrus plants and distant pyramids under a cloudless sky, conveying a sense of timeless wonder and serene majesty.")
119
+ exported = torch.export.export(
120
+ pipe.transformer,
121
+ args=call.args,
122
+ kwargs=call.kwargs,
123
+ )
124
+ return spaces.aoti_compile(exported)
125
+
126
  def load_model():
127
  pipe = StableDiffusion3Pipeline.from_pretrained(
128
  "ford442/stable-diffusion-3.5-large-bf16",
 
134
  pipe.transformer=ll_transformer
135
  pipe.load_lora_weights("ford442/sdxl-vae-bf16", weight_name="LoRA/UltraReal.safetensors")
136
  pipe.to(device=device, dtype=torch.bfloat16)
137
+ for name, module in pipe.unet.named_modules():
138
+ if isinstance(module, Attention):
139
+ module.processor = fa_processor
140
+
141
  upscaler_2 = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(device)
142
+
143
  return pipe, upscaler_2
144
+
145
+ fa_processor = FlashAttentionProcessor()
146
 
147
  pipe, upscaler_2 = load_model()
148
 
149
+ compiled_transformer = compile_transformer()
150
+ spaces.aoti_apply(compiled_transformer, pipe.transformer)
151
+
152
  MAX_SEED = np.iinfo(np.int32).max
153
  MAX_IMAGE_SIZE = 4096
154