Ryukijano committed
Commit e5750a9 · verified · 1 Parent(s): 89fffa1

Update custom_pipeline.py

Files changed (1):
  1. custom_pipeline.py +3 -6
custom_pipeline.py CHANGED

@@ -48,7 +48,7 @@ class FluxWithCFGPipeline(FluxPipeline):
         with progressively increasing resolution for faster generation.
         """
     @torch.inference_mode()
-    def generate_images(
+    async def generate_images(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
@@ -64,7 +64,6 @@ class FluxWithCFGPipeline(FluxPipeline):
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         max_sequence_length: int = 300,
         generate_with_graph = None
     ):
@@ -84,7 +83,6 @@ class FluxWithCFGPipeline(FluxPipeline):
         )

         self._guidance_scale = guidance_scale
-        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

         # 2. Define call parameters
@@ -92,7 +90,7 @@ class FluxWithCFGPipeline(FluxPipeline):
         device = self._execution_device

         # 3. Encode prompt
-        lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
+        lora_scale = None
         prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
@@ -140,7 +138,7 @@ class FluxWithCFGPipeline(FluxPipeline):
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

                 if generate_with_graph:
-                    return generate_with_graph(latents, prompt_embeds, pooled_prompt_embeds, text_ids, latent_image_ids, timestep)
+                    return await generate_with_graph(latents, prompt_embeds, pooled_prompt_embeds, text_ids, latent_image_ids, timestep)
                 else:
                     noise_pred = self.transformer(
                         hidden_states=latents,
@@ -150,7 +148,6 @@ class FluxWithCFGPipeline(FluxPipeline):
                         encoder_hidden_states=prompt_embeds,
                         txt_ids=text_ids,
                         img_ids=latent_image_ids,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
 
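Because generate_images is now async, synchronous call sites need an event loop; and with joint_attention_kwargs removed there is no longer a per-call way to pass attention options such as a LoRA scale, since the method now pins lora_scale to None. A minimal, hypothetical usage sketch: the checkpoint name and prompt are illustrative, and loading follows the usual diffusers from_pretrained pattern inherited from FluxPipeline.

    import asyncio
    import torch
    from custom_pipeline import FluxWithCFGPipeline

    async def main():
        pipe = FluxWithCFGPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",  # illustrative checkpoint
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        # After this commit the method is a coroutine and must be awaited.
        # Passing joint_attention_kwargs={"scale": ...} is no longer accepted.
        return await pipe.generate_images(prompt="a watercolor fox, studio light")

    result = asyncio.run(main())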