Ryukijano committed (verified)
Commit 105e0dd
Parent(s): 16c45c8

Update custom_pipeline.py

Files changed (1):
  1. custom_pipeline.py +18 -19
custom_pipeline.py CHANGED
@@ -64,8 +64,8 @@ class FluxWithCFGPipeline(FluxPipeline):
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         max_sequence_length: int = 300,
-        generate_with_graph = None
     ):
         """Generates images and yields intermediate results during the denoising process."""
         height = height or self.default_sample_size * self.vae_scale_factor
@@ -83,6 +83,7 @@ class FluxWithCFGPipeline(FluxPipeline):
         )
 
         self._guidance_scale = guidance_scale
+        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
 
         # 2. Define call parameters
@@ -90,7 +91,7 @@ class FluxWithCFGPipeline(FluxPipeline):
         device = self._execution_device
 
         # 3. Encode prompt
-        lora_scale = None
+        lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
         prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
@@ -137,23 +138,21 @@ class FluxWithCFGPipeline(FluxPipeline):
 
             timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-            if generate_with_graph:
-                return generate_with_graph(latents, prompt_embeds, pooled_prompt_embeds, text_ids, latent_image_ids, timestep)
-            else:
-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    return_dict=False,
-                )[0]
-
-            # Yield intermediate result
-            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-            torch.cuda.empty_cache()
+            noise_pred = self.transformer(
+                hidden_states=latents,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=pooled_prompt_embeds,
+                encoder_hidden_states=prompt_embeds,
+                txt_ids=text_ids,
+                img_ids=latent_image_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            # Yield intermediate result
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+            torch.cuda.empty_cache()
 
         # Final image
         return self._decode_latents_to_image(latents, height, width, output_type)
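
Taken together, the change removes the experimental generate_with_graph callback and instead threads joint_attention_kwargs through the pipeline: it is stored on self, forwarded to every transformer call, and its "scale" entry is reused as the lora_scale passed to encode_prompt. A minimal usage sketch follows; the base checkpoint, the LoRA repo, and the 0.8 scale are illustrative assumptions, not part of the commit.

import torch
from custom_pipeline import FluxWithCFGPipeline

# Load the custom pipeline; the base checkpoint is an assumed example.
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Hypothetical LoRA; FluxPipeline inherits load_lora_weights from diffusers.
pipe.load_lora_weights("some-user/some-flux-lora")

# joint_attention_kwargs is the parameter this commit adds: it is forwarded
# to the transformer on every denoising step, and its "scale" entry doubles
# as lora_scale for prompt encoding.
image = pipe(
    prompt="a watercolor fox in a snowy forest",
    joint_attention_kwargs={"scale": 0.8},
    max_sequence_length=300,
)

Note that the new Optional[Dict[str, Any]] annotation assumes Dict and Any are imported from typing elsewhere in custom_pipeline.py.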