Spaces:
Running
Running
anbucur
committed on
Commit
·
5d8e518
1
Parent(s):
25e8e9c
Refactor generate_design method in ProductionDesignModel for improved image handling and variation generation
Browse files
- Updated the method to accept various image types (PIL Image, numpy array, torch tensor) and ensured proper conversion to RGB format.
- Enhanced parameter handling by consolidating the retrieval of prompt, number of variations, and other settings from kwargs.
- Implemented distinct seed generation for each variation to ensure diversity in outputs.
- Improved error handling and logging for better traceability during the design generation process.
- Cleared CUDA cache after each variation generation to optimize memory usage.
- prod_model.py +54 -55
prod_model.py
CHANGED
@@ -162,83 +162,82 @@ class ProductionDesignModel(DesignModel):
|
|
162 |
if torch.cuda.is_available():
|
163 |
torch.cuda.empty_cache()
|
164 |
|
165 |
-
def generate_design(self, image
|
166 |
-
"""
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
"""
|
169 |
try:
|
170 |
-
#
|
171 |
-
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
# Get parameters
|
176 |
-
|
177 |
-
guidance_scale = float(kwargs.get('guidance_scale', 10.0))
|
178 |
num_steps = int(kwargs.get('num_steps', 50))
|
|
|
179 |
strength = float(kwargs.get('strength', 0.9))
|
180 |
-
|
181 |
-
|
182 |
-
logging.info(f"Generating design with parameters: guidance_scale={guidance_scale}, "
|
183 |
-
f"num_steps={num_steps}, strength={strength}, img_size={img_size}")
|
184 |
-
|
185 |
-
# Prepare prompt
|
186 |
-
pos_prompt = f"{prompt}, {self.additional_quality_suffix}"
|
187 |
-
|
188 |
-
# Process input image
|
189 |
-
orig_size = image.size
|
190 |
-
input_image = self._resize_image(image, img_size)
|
191 |
|
192 |
-
#
|
193 |
-
|
|
|
194 |
|
195 |
-
# Generate segmentation
|
196 |
-
seg_map = self._segment_image(input_image)
|
197 |
-
|
198 |
-
# Generate IP-adapter reference image
|
199 |
-
self._flush()
|
200 |
-
ip_image = self.guide_pipe(
|
201 |
-
pos_prompt,
|
202 |
-
num_inference_steps=num_steps,
|
203 |
-
negative_prompt=self.neg_prompt,
|
204 |
-
generator=self.generator
|
205 |
-
).images[0]
|
206 |
-
|
207 |
-
# Generate variations
|
208 |
variations = []
|
209 |
for i in range(num_variations):
|
210 |
try:
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
215 |
num_inference_steps=num_steps,
|
216 |
-
strength=strength,
|
217 |
guidance_scale=guidance_scale,
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
control_image=[depth_map, seg_map],
|
222 |
-
controlnet_conditioning_scale=[0.5, 0.5]
|
223 |
).images[0]
|
224 |
|
225 |
-
|
226 |
-
|
227 |
-
variations.append(variation)
|
228 |
|
229 |
except Exception as e:
|
230 |
-
logging.error(f"Error generating variation {i}: {e}")
|
231 |
continue
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
233 |
if not variations:
|
234 |
logging.warning("No variations were generated successfully")
|
235 |
-
return [image] # Return original image if no variations
|
236 |
-
|
237 |
return variations
|
238 |
-
|
239 |
except Exception as e:
|
240 |
-
logging.error(f"Error in generate_design: {e}")
|
241 |
-
return [image] # Return original image
|
242 |
|
243 |
def __del__(self):
|
244 |
"""Cleanup when the model is deleted"""
|
|
|
162 |
if torch.cuda.is_available():
|
163 |
torch.cuda.empty_cache()
|
164 |
|
165 |
+
def generate_design(self, image, num_variations=1, **kwargs):
    """Generate design variations of an input image with the diffusion pipeline.

    Args:
        image: Input image as a PIL Image, a numpy array (H, W, C), or a
            torch tensor in (C, H, W), (H, W, C), or batched (1, C, H, W)
            layout.
        num_variations: Number of variations to generate.
        **kwargs: Optional generation settings:
            prompt (str): Text prompt (default '').
            num_steps (int): Inference step count (default 50).
            guidance_scale (float): Classifier-free guidance (default 10.0).
            strength (float): img2img strength (default 0.9).
            seed (int): Base seed; current time is used when omitted.

    Returns:
        List of generated PIL images. Falls back to ``[image]`` when no
        variation could be produced or an unexpected error occurs.
    """
    try:
        # Normalize the input to a PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, torch.Tensor):
            arr = image.detach().cpu()
            if arr.ndim == 4 and arr.shape[0] == 1:
                arr = arr[0]  # drop the batch dimension
            if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                # Torch convention is channels-first; PIL expects (H, W, C).
                arr = arr.permute(1, 2, 0)
            arr = arr.numpy()
            if arr.dtype != np.uint8:
                # Scale only float data in [0, 1]; 0-255 data is kept as-is
                # (an unconditional *255 would double-scale such input).
                if arr.max() <= 1.0:
                    arr = arr * 255
                arr = arr.astype(np.uint8)
            image = Image.fromarray(np.squeeze(arr))

        if not isinstance(image, Image.Image):
            raise ValueError(f"Unsupported image type: {type(image)}")

        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Get parameters
        prompt = kwargs.get('prompt', '')
        num_steps = int(kwargs.get('num_steps', 50))
        guidance_scale = float(kwargs.get('guidance_scale', 10.0))
        strength = float(kwargs.get('strength', 0.9))
        seed_param = kwargs.get('seed')

        # Handle seed: one base seed, offset per variation for diversity.
        base_seed = int(time.time()) if seed_param is None else int(seed_param)
        logging.info(f"Using base seed: {base_seed}")

        variations = []
        for i in range(num_variations):
            try:
                # Generate distinct seed for each variation
                seed = base_seed + i
                generator = torch.Generator(device=self.device).manual_seed(seed)

                # Generate variation
                output = self.pipe(
                    prompt=prompt,
                    image=image,
                    num_inference_steps=num_steps,
                    guidance_scale=guidance_scale,
                    strength=strength,
                    generator=generator,
                    negative_prompt=self.neg_prompt
                ).images[0]

                variations.append(output)
                logging.info(f"Successfully generated variation {i} with seed {seed}")

            except Exception as e:
                # A failed variation is skipped; later ones may still succeed.
                logging.error(f"Error generating variation {i}: {str(e)}")
                continue

            finally:
                # Clear CUDA cache after each variation to bound memory use.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        if not variations:
            logging.warning("No variations were generated successfully")
            return [image]  # Return original image if no variations generated

        return variations

    except Exception as e:
        logging.error(f"Error in generate_design: {str(e)}")
        return [image]  # Return original image on error
|
241 |
|
242 |
def __del__(self):
|
243 |
"""Cleanup when the model is deleted"""
|