Update handler.py
Browse files- handler.py +10 -5
handler.py
CHANGED
|
@@ -18,7 +18,7 @@ class EndpointHandler:
|
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
| 19 |
logging.info("Using HF_TOKEN")
|
| 20 |
|
| 21 |
-
# Load model and pipeline
|
| 22 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
| 23 |
model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
|
| 24 |
)
|
|
@@ -30,6 +30,10 @@ class EndpointHandler:
|
|
| 30 |
)
|
| 31 |
self.pipe.to("cuda")
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def preprocess(self, data):
|
| 34 |
# Load image from file
|
| 35 |
image_file = data.get("control_image", None)
|
|
@@ -37,8 +41,9 @@ class EndpointHandler:
|
|
| 37 |
raise ValueError("Missing control_image in input.")
|
| 38 |
image = Image.open(image_file)
|
| 39 |
w, h = image.size
|
| 40 |
-
#
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
def postprocess(self, output):
|
| 44 |
# Save output image to a file-like object
|
|
@@ -54,8 +59,8 @@ class EndpointHandler:
|
|
| 54 |
output_image = self.pipe(
|
| 55 |
prompt=data.get("prompt", ""),
|
| 56 |
control_image=control_image,
|
| 57 |
-
controlnet_conditioning_scale=0.
|
| 58 |
-
num_inference_steps=
|
| 59 |
height=control_image.size[1],
|
| 60 |
width=control_image.size[0],
|
| 61 |
).images[0]
|
|
|
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
| 19 |
logging.info("Using HF_TOKEN")
|
| 20 |
|
| 21 |
+
# Load model and pipeline with memory optimizations
|
| 22 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
| 23 |
model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
|
| 24 |
)
|
|
|
|
| 30 |
)
|
| 31 |
self.pipe.to("cuda")
|
| 32 |
|
| 33 |
+
# Enable memory optimizations
|
| 34 |
+
self.pipe.enable_attention_slicing()
|
| 35 |
+
self.pipe.enable_sequential_cpu_offload()
|
| 36 |
+
|
| 37 |
def preprocess(self, data):
|
| 38 |
# Load image from file
|
| 39 |
image_file = data.get("control_image", None)
|
|
|
|
| 41 |
raise ValueError("Missing control_image in input.")
|
| 42 |
image = Image.open(image_file)
|
| 43 |
w, h = image.size
|
| 44 |
+
# Reduce image size for memory efficiency
|
| 45 |
+
image = image.resize((w // 2, h // 2)) # Downscale to save memory
|
| 46 |
+
return image.resize((w, h)) # Upscale back after processing
|
| 47 |
|
| 48 |
def postprocess(self, output):
|
| 49 |
# Save output image to a file-like object
|
|
|
|
| 59 |
output_image = self.pipe(
|
| 60 |
prompt=data.get("prompt", ""),
|
| 61 |
control_image=control_image,
|
| 62 |
+
controlnet_conditioning_scale=0.5, # Slightly reduced for memory efficiency
|
| 63 |
+
num_inference_steps=15, # Reduced steps to save memory
|
| 64 |
height=control_image.size[1],
|
| 65 |
width=control_image.size[0],
|
| 66 |
).images[0]
|