anbucur commited on
Commit
5d8e518
·
1 Parent(s): 25e8e9c

Refactor generate_design method in ProductionDesignModel for improved image handling and variation generation

Browse files

- Updated the method to accept various image types (PIL Image, numpy array, torch tensor) and ensure proper conversion to RGB format.
- Enhanced parameter handling by consolidating the retrieval of prompt, number of variations, and other settings from kwargs.
- Implemented distinct seed generation for each variation to ensure diversity in outputs.
- Improved error handling and logging for better traceability during the design generation process.
- Cleared CUDA cache after each variation generation to optimize memory usage.

Files changed (1) hide show
  1. prod_model.py +54 -55
prod_model.py CHANGED
@@ -162,83 +162,82 @@ class ProductionDesignModel(DesignModel):
162
  if torch.cuda.is_available():
163
  torch.cuda.empty_cache()
164
 
165
- def generate_design(self, image: Image.Image, prompt: str, **kwargs) -> List[Image.Image]:
166
- """
167
- Generate design variations based on input image and prompt
 
 
 
 
 
 
 
168
  """
169
  try:
170
- # Set seed
171
- seed_param = kwargs.get('seed')
172
- base_seed = int(time.time()) if seed_param is None else int(seed_param)
173
- self.generator = torch.Generator(device=self.device).manual_seed(base_seed)
 
 
 
 
 
 
 
 
 
174
 
175
  # Get parameters
176
- num_variations = kwargs.get('num_variations', 1)
177
- guidance_scale = float(kwargs.get('guidance_scale', 10.0))
178
  num_steps = int(kwargs.get('num_steps', 50))
 
179
  strength = float(kwargs.get('strength', 0.9))
180
- img_size = int(kwargs.get('img_size', 768))
181
-
182
- logging.info(f"Generating design with parameters: guidance_scale={guidance_scale}, "
183
- f"num_steps={num_steps}, strength={strength}, img_size={img_size}")
184
-
185
- # Prepare prompt
186
- pos_prompt = f"{prompt}, {self.additional_quality_suffix}"
187
-
188
- # Process input image
189
- orig_size = image.size
190
- input_image = self._resize_image(image, img_size)
191
 
192
- # Generate depth map
193
- depth_map = self._get_depth_map(input_image)
 
194
 
195
- # Generate segmentation
196
- seg_map = self._segment_image(input_image)
197
-
198
- # Generate IP-adapter reference image
199
- self._flush()
200
- ip_image = self.guide_pipe(
201
- pos_prompt,
202
- num_inference_steps=num_steps,
203
- negative_prompt=self.neg_prompt,
204
- generator=self.generator
205
- ).images[0]
206
-
207
- # Generate variations
208
  variations = []
209
  for i in range(num_variations):
210
  try:
211
- self._flush()
212
- variation = self.pipe(
213
- prompt=pos_prompt,
214
- negative_prompt=self.neg_prompt,
 
 
 
 
215
  num_inference_steps=num_steps,
216
- strength=strength,
217
  guidance_scale=guidance_scale,
218
- generator=self.generator,
219
- image=input_image,
220
- ip_adapter_image=ip_image,
221
- control_image=[depth_map, seg_map],
222
- controlnet_conditioning_scale=[0.5, 0.5]
223
  ).images[0]
224
 
225
- # Resize back to original size
226
- variation = variation.resize(orig_size, Image.LANCZOS)
227
- variations.append(variation)
228
 
229
  except Exception as e:
230
- logging.error(f"Error generating variation {i}: {e}")
231
  continue
232
-
 
 
 
 
 
233
  if not variations:
234
  logging.warning("No variations were generated successfully")
235
- return [image] # Return original image if no variations were generated
236
-
237
  return variations
238
-
239
  except Exception as e:
240
- logging.error(f"Error in generate_design: {e}")
241
- return [image] # Return original image in case of error
242
 
243
  def __del__(self):
244
  """Cleanup when the model is deleted"""
 
162
  if torch.cuda.is_available():
163
  torch.cuda.empty_cache()
164
 
165
+ def generate_design(self, image, num_variations=1, **kwargs):
166
+ """Generate design variations using the model.
167
+
168
+ Args:
169
+ image: Input image (PIL Image, numpy array, or torch tensor)
170
+ num_variations: Number of variations to generate
171
+ **kwargs: Additional parameters like prompt, num_steps, guidance_scale, strength
172
+
173
+ Returns:
174
+ List of generated images
175
  """
176
  try:
177
+ # Convert image to PIL Image if needed
178
+ if isinstance(image, np.ndarray):
179
+ image = Image.fromarray(image)
180
+ elif isinstance(image, torch.Tensor):
181
+ # Convert tensor to numpy then PIL
182
+ image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8))
183
+
184
+ if not isinstance(image, Image.Image):
185
+ raise ValueError(f"Unsupported image type: {type(image)}")
186
+
187
+ # Ensure image is RGB
188
+ if image.mode != "RGB":
189
+ image = image.convert("RGB")
190
 
191
  # Get parameters
192
+ prompt = kwargs.get('prompt', '')
 
193
  num_steps = int(kwargs.get('num_steps', 50))
194
+ guidance_scale = float(kwargs.get('guidance_scale', 10.0))
195
  strength = float(kwargs.get('strength', 0.9))
196
+ seed_param = kwargs.get('seed')
 
 
 
 
 
 
 
 
 
 
197
 
198
+ # Handle seed
199
+ base_seed = int(time.time()) if seed_param is None else int(seed_param)
200
+ logging.info(f"Using base seed: {base_seed}")
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  variations = []
203
  for i in range(num_variations):
204
  try:
205
+ # Generate distinct seed for each variation
206
+ seed = base_seed + i
207
+ generator = torch.Generator(device=self.device).manual_seed(seed)
208
+
209
+ # Generate variation
210
+ output = self.pipe(
211
+ prompt=prompt,
212
+ image=image,
213
  num_inference_steps=num_steps,
 
214
  guidance_scale=guidance_scale,
215
+ strength=strength,
216
+ generator=generator,
217
+ negative_prompt=self.neg_prompt
 
 
218
  ).images[0]
219
 
220
+ variations.append(output)
221
+ logging.info(f"Successfully generated variation {i} with seed {seed}")
 
222
 
223
  except Exception as e:
224
+ logging.error(f"Error generating variation {i}: {str(e)}")
225
  continue
226
+
227
+ finally:
228
+ # Clear CUDA cache after each variation
229
+ if torch.cuda.is_available():
230
+ torch.cuda.empty_cache()
231
+
232
  if not variations:
233
  logging.warning("No variations were generated successfully")
234
+ return [image] # Return original image if no variations generated
235
+
236
  return variations
237
+
238
  except Exception as e:
239
+ logging.error(f"Error in generate_design: {str(e)}")
240
+ return [image] # Return original image on error
241
 
242
  def __del__(self):
243
  """Cleanup when the model is deleted"""