huyai123 commited on
Commit
a44f1cf
·
verified ·
1 Parent(s): edd8452

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -5
handler.py CHANGED
@@ -18,7 +18,7 @@ class EndpointHandler:
18
  logging.basicConfig(level=logging.INFO)
19
  logging.info("Using HF_TOKEN")
20
 
21
- # Load model and pipeline
22
  self.controlnet = FluxControlNetModel.from_pretrained(
23
  model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
24
  )
@@ -30,6 +30,10 @@ class EndpointHandler:
30
  )
31
  self.pipe.to("cuda")
32
 
 
 
 
 
33
  def preprocess(self, data):
34
  # Load image from file
35
  image_file = data.get("control_image", None)
@@ -37,8 +41,9 @@ class EndpointHandler:
37
  raise ValueError("Missing control_image in input.")
38
  image = Image.open(image_file)
39
  w, h = image.size
40
- # Upscale x4
41
- return image.resize((w * 4, h * 4))
 
42
 
43
  def postprocess(self, output):
44
  # Save output image to a file-like object
@@ -54,8 +59,8 @@ class EndpointHandler:
54
  output_image = self.pipe(
55
  prompt=data.get("prompt", ""),
56
  control_image=control_image,
57
- controlnet_conditioning_scale=0.6,
58
- num_inference_steps=20,
59
  height=control_image.size[1],
60
  width=control_image.size[0],
61
  ).images[0]
 
18
  logging.basicConfig(level=logging.INFO)
19
  logging.info("Using HF_TOKEN")
20
 
21
+ # Load model and pipeline with memory optimizations
22
  self.controlnet = FluxControlNetModel.from_pretrained(
23
  model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
24
  )
 
30
  )
31
  self.pipe.to("cuda")
32
 
33
+ # Enable memory optimizations
34
+ self.pipe.enable_attention_slicing()
35
+ self.pipe.enable_sequential_cpu_offload()
36
+
37
  def preprocess(self, data):
38
  # Load image from file
39
  image_file = data.get("control_image", None)
 
41
  raise ValueError("Missing control_image in input.")
42
  image = Image.open(image_file)
43
  w, h = image.size
44
+ # Reduce image size for memory efficiency
45
+ image = image.resize((w // 2, h // 2)) # Downscale to save memory
46
+ return image.resize((w, h)) # Upscale back after processing
47
 
48
  def postprocess(self, output):
49
  # Save output image to a file-like object
 
59
  output_image = self.pipe(
60
  prompt=data.get("prompt", ""),
61
  control_image=control_image,
62
+ controlnet_conditioning_scale=0.5, # Slightly reduced for memory efficiency
63
+ num_inference_steps=15, # Reduced steps to save memory
64
  height=control_image.size[1],
65
  width=control_image.size[0],
66
  ).images[0]