Update handler.py
Browse files- handler.py +11 -10
handler.py
CHANGED
@@ -14,11 +14,14 @@ class EndpointHandler:
|
|
14 |
if not HF_TOKEN:
|
15 |
raise ValueError("HF_TOKEN environment variable is not set")
|
16 |
|
17 |
-
# Log the token for debugging
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logging.info("Using HF_TOKEN")
|
20 |
|
21 |
-
#
|
|
|
|
|
|
|
22 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
23 |
model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
|
24 |
)
|
@@ -29,9 +32,7 @@ class EndpointHandler:
|
|
29 |
use_auth_token=HF_TOKEN
|
30 |
)
|
31 |
self.pipe.to("cuda")
|
32 |
-
|
33 |
-
# Enable memory optimizations
|
34 |
-
self.pipe.enable_attention_slicing()
|
35 |
self.pipe.enable_sequential_cpu_offload()
|
36 |
|
37 |
def preprocess(self, data):
|
@@ -41,9 +42,7 @@ class EndpointHandler:
|
|
41 |
raise ValueError("Missing control_image in input.")
|
42 |
image = Image.open(image_file)
|
43 |
w, h = image.size
|
44 |
-
# Reduce
|
45 |
-
image = image.resize((w // 2, h // 2)) # Downscale to save memory
|
46 |
-
return image.resize((w, h)) # Upscale back after processing
|
47 |
|
48 |
def postprocess(self, output):
|
49 |
# Save output image to a file-like object
|
@@ -55,12 +54,14 @@ class EndpointHandler:
|
|
55 |
def inference(self, data):
|
56 |
# Preprocess input
|
57 |
control_image = self.preprocess(data)
|
|
|
|
|
58 |
# Generate output
|
59 |
output_image = self.pipe(
|
60 |
prompt=data.get("prompt", ""),
|
61 |
control_image=control_image,
|
62 |
-
controlnet_conditioning_scale=0.5, #
|
63 |
-
num_inference_steps=15, # Reduced steps
|
64 |
height=control_image.size[1],
|
65 |
width=control_image.size[0],
|
66 |
).images[0]
|
|
|
14 |
if not HF_TOKEN:
|
15 |
raise ValueError("HF_TOKEN environment variable is not set")
|
16 |
|
17 |
+
# Log that HF_TOKEN is configured (the token value itself is not logged)
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logging.info("Using HF_TOKEN")
|
20 |
|
21 |
+
# Clear GPU memory
|
22 |
+
torch.cuda.empty_cache()
|
23 |
+
|
24 |
+
# Load model and pipeline
|
25 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
26 |
model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
|
27 |
)
|
|
|
32 |
use_auth_token=HF_TOKEN
|
33 |
)
|
34 |
self.pipe.to("cuda")
|
35 |
+
self.pipe.enable_attention_slicing("auto")
|
|
|
|
|
36 |
self.pipe.enable_sequential_cpu_offload()
|
37 |
|
38 |
def preprocess(self, data):
|
|
|
42 |
raise ValueError("Missing control_image in input.")
|
43 |
image = Image.open(image_file)
|
44 |
w, h = image.size
|
45 |
+
return image.resize((w * 2, h * 2)) # Upscale the control image 2x in each dimension
|
|
|
|
|
46 |
|
47 |
def postprocess(self, output):
|
48 |
# Save output image to a file-like object
|
|
|
54 |
def inference(self, data):
|
55 |
# Preprocess input
|
56 |
control_image = self.preprocess(data)
|
57 |
+
# Clear GPU memory
|
58 |
+
torch.cuda.empty_cache()
|
59 |
# Generate output
|
60 |
output_image = self.pipe(
|
61 |
prompt=data.get("prompt", ""),
|
62 |
control_image=control_image,
|
63 |
+
controlnet_conditioning_scale=0.5, # Weight of the ControlNet guidance (not a memory setting)
|
64 |
+
num_inference_steps=15, # Reduced steps
|
65 |
height=control_image.size[1],
|
66 |
width=control_image.size[0],
|
67 |
).images[0]
|