Update handler.py
Browse files- handler.py +10 -5
handler.py
CHANGED
@@ -18,7 +18,7 @@ class EndpointHandler:
|
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logging.info("Using HF_TOKEN")
|
20 |
|
21 |
-
# Load model and pipeline
|
22 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
23 |
model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
|
24 |
)
|
@@ -30,6 +30,10 @@ class EndpointHandler:
|
|
30 |
)
|
31 |
self.pipe.to("cuda")
|
32 |
|
|
|
|
|
|
|
|
|
33 |
def preprocess(self, data):
|
34 |
# Load image from file
|
35 |
image_file = data.get("control_image", None)
|
@@ -37,8 +41,9 @@ class EndpointHandler:
|
|
37 |
raise ValueError("Missing control_image in input.")
|
38 |
image = Image.open(image_file)
|
39 |
w, h = image.size
|
40 |
-
#
|
41 |
-
|
|
|
42 |
|
43 |
def postprocess(self, output):
|
44 |
# Save output image to a file-like object
|
@@ -54,8 +59,8 @@ class EndpointHandler:
|
|
54 |
output_image = self.pipe(
|
55 |
prompt=data.get("prompt", ""),
|
56 |
control_image=control_image,
|
57 |
-
controlnet_conditioning_scale=0.
|
58 |
-
num_inference_steps=
|
59 |
height=control_image.size[1],
|
60 |
width=control_image.size[0],
|
61 |
).images[0]
|
|
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logging.info("Using HF_TOKEN")
|
20 |
|
21 |
+
# Load model and pipeline with memory optimizations
|
22 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
23 |
model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
|
24 |
)
|
|
|
30 |
)
|
31 |
self.pipe.to("cuda")
|
32 |
|
33 |
+
# Enable memory optimizations
|
34 |
+
self.pipe.enable_attention_slicing()
|
35 |
+
self.pipe.enable_sequential_cpu_offload()
|
36 |
+
|
37 |
def preprocess(self, data):
|
38 |
# Load image from file
|
39 |
image_file = data.get("control_image", None)
|
|
|
41 |
raise ValueError("Missing control_image in input.")
|
42 |
image = Image.open(image_file)
|
43 |
w, h = image.size
|
44 |
+
# Reduce image size for memory efficiency
|
45 |
+
image = image.resize((w // 2, h // 2)) # Downscale to save memory
|
46 |
+
return image.resize((w, h)) # Upscale back after processing
|
47 |
|
48 |
def postprocess(self, output):
|
49 |
# Save output image to a file-like object
|
|
|
59 |
output_image = self.pipe(
|
60 |
prompt=data.get("prompt", ""),
|
61 |
control_image=control_image,
|
62 |
+
controlnet_conditioning_scale=0.5, # Slightly reduced for memory efficiency
|
63 |
+
num_inference_steps=15, # Reduced steps to save memory
|
64 |
height=control_image.size[1],
|
65 |
width=control_image.size[0],
|
66 |
).images[0]
|