Diffusers
3v324v23 commited on
Commit
631a82f
·
1 Parent(s): 2c58008

New handler

Browse files
Files changed (1) hide show
  1. handler.py +10 -8
handler.py CHANGED
@@ -87,9 +87,7 @@ class EndpointHandler:
87
  self.total_steps = {}
88
  self.inference_in_progress = False
89
 
90
- self.executor = ThreadPoolExecutor(
91
- max_workers=1
92
- ) # Vous pouvez ajuster max_workers en fonction de vos besoins
93
 
94
  # load the optimized model
95
  self.pipe = DiffusionPipeline.from_pretrained(
@@ -226,13 +224,13 @@ class EndpointHandler:
226
  """Clean up the data related to a specific request ID."""
227
 
228
  # Remove the request ID from the progress dictionary
229
- self.inference_progress.pop(request_id, None)
230
 
231
  # Remove the request ID from the images dictionary
232
- self.inference_images.pop(request_id, None)
233
 
234
  # Remove the request ID from the total_steps dictionary
235
- self.total_steps.pop(request_id, None)
236
 
237
  # Set inference to False
238
  self.inference_in_progress = False
@@ -269,6 +267,7 @@ class EndpointHandler:
269
 
270
  except Exception as e:
271
  print(f"Error: {e}")
 
272
 
273
  # Store progress and image
274
  progress_percentage = (
@@ -278,7 +277,9 @@ class EndpointHandler:
278
  self.inference_progress[request_id] = progress_percentage
279
  self.inference_images[request_id] = img_str
280
 
281
- def check_progress(self, request_id: str) -> Dict[str, Union[str, float]]:
 
 
282
  progress = self.inference_progress.get(request_id, 0)
283
  latest_image = self.inference_images.get(request_id, None)
284
 
@@ -387,7 +388,8 @@ class EndpointHandler:
387
 
388
  except Exception as e:
389
  # Handle any other exceptions and return an error response
390
- return {"flag": "error", "message": str(e)}
 
391
 
392
  def __call__(self, data: Any) -> Dict:
393
  """Handle incoming requests."""
 
87
  self.total_steps = {}
88
  self.inference_in_progress = False
89
 
90
+ self.executor = ThreadPoolExecutor(max_workers=1)
 
 
91
 
92
  # load the optimized model
93
  self.pipe = DiffusionPipeline.from_pretrained(
 
224
  """Clean up the data related to a specific request ID."""
225
 
226
  # Remove the request ID from the progress dictionary
227
+ self.inference_progress.clear()
228
 
229
  # Remove the request ID from the images dictionary
230
+ self.inference_images.clear()
231
 
232
  # Remove the request ID from the total_steps dictionary
233
+ self.total_steps.clear()
234
 
235
  # Set inference to False
236
  self.inference_in_progress = False
 
267
 
268
  except Exception as e:
269
  print(f"Error: {e}")
270
+ raise
271
 
272
  # Store progress and image
273
  progress_percentage = (
 
277
  self.inference_progress[request_id] = progress_percentage
278
  self.inference_images[request_id] = img_str
279
 
280
+ print(self.inference_progress)
281
+
282
+ def check_progress(self, request_id: str) -> dict[str, str | float]:
283
  progress = self.inference_progress.get(request_id, 0)
284
  latest_image = self.inference_images.get(request_id, None)
285
 
 
388
 
389
  except Exception as e:
390
  # Handle any other exceptions and return an error response
391
+ print(f"Error in start inference {e}")
392
+ raise
393
 
394
  def __call__(self, data: Any) -> Dict:
395
  """Handle incoming requests."""