Update handler.py
Browse files- handler.py +11 -10
handler.py
CHANGED
|
@@ -14,11 +14,14 @@ class EndpointHandler:
|
|
| 14 |
if not HF_TOKEN:
|
| 15 |
raise ValueError("HF_TOKEN environment variable is not set")
|
| 16 |
|
| 17 |
-
# Log the token for debugging
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
| 19 |
logging.info("Using HF_TOKEN")
|
| 20 |
|
| 21 |
-
#
|
|
|
|
|
|
|
|
|
|
| 22 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
| 23 |
model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
|
| 24 |
)
|
|
@@ -29,9 +32,7 @@ class EndpointHandler:
|
|
| 29 |
use_auth_token=HF_TOKEN
|
| 30 |
)
|
| 31 |
self.pipe.to("cuda")
|
| 32 |
-
|
| 33 |
-
# Enable memory optimizations
|
| 34 |
-
self.pipe.enable_attention_slicing()
|
| 35 |
self.pipe.enable_sequential_cpu_offload()
|
| 36 |
|
| 37 |
def preprocess(self, data):
|
|
@@ -41,9 +42,7 @@ class EndpointHandler:
|
|
| 41 |
raise ValueError("Missing control_image in input.")
|
| 42 |
image = Image.open(image_file)
|
| 43 |
w, h = image.size
|
| 44 |
-
# Reduce
|
| 45 |
-
image = image.resize((w // 2, h // 2)) # Downscale to save memory
|
| 46 |
-
return image.resize((w, h)) # Upscale back after processing
|
| 47 |
|
| 48 |
def postprocess(self, output):
|
| 49 |
# Save output image to a file-like object
|
|
@@ -55,12 +54,14 @@ class EndpointHandler:
|
|
| 55 |
def inference(self, data):
|
| 56 |
# Preprocess input
|
| 57 |
control_image = self.preprocess(data)
|
|
|
|
|
|
|
| 58 |
# Generate output
|
| 59 |
output_image = self.pipe(
|
| 60 |
prompt=data.get("prompt", ""),
|
| 61 |
control_image=control_image,
|
| 62 |
-
controlnet_conditioning_scale=0.5, #
|
| 63 |
-
num_inference_steps=15, # Reduced steps
|
| 64 |
height=control_image.size[1],
|
| 65 |
width=control_image.size[0],
|
| 66 |
).images[0]
|
|
|
|
| 14 |
if not HF_TOKEN:
|
| 15 |
raise ValueError("HF_TOKEN environment variable is not set")
|
| 16 |
|
| 17 |
+
# Log that HF_TOKEN is in use (the token value itself is not logged)
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
| 19 |
logging.info("Using HF_TOKEN")
|
| 20 |
|
| 21 |
+
# Clear GPU memory
|
| 22 |
+
torch.cuda.empty_cache()
|
| 23 |
+
|
| 24 |
+
# Load model and pipeline
|
| 25 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
| 26 |
model_dir, torch_dtype=torch.float16, use_auth_token=HF_TOKEN
|
| 27 |
)
|
|
|
|
| 32 |
use_auth_token=HF_TOKEN
|
| 33 |
)
|
| 34 |
self.pipe.to("cuda")
|
| 35 |
+
self.pipe.enable_attention_slicing("auto")
|
|
|
|
|
|
|
| 36 |
self.pipe.enable_sequential_cpu_offload()
|
| 37 |
|
| 38 |
def preprocess(self, data):
|
|
|
|
| 42 |
raise ValueError("Missing control_image in input.")
|
| 43 |
image = Image.open(image_file)
|
| 44 |
w, h = image.size
|
| 45 |
+
return image.resize((w * 2, h * 2)) # Upscale 2x; upscale factor reduced to save memory
|
|
|
|
|
|
|
| 46 |
|
| 47 |
def postprocess(self, output):
|
| 48 |
# Save output image to a file-like object
|
|
|
|
| 54 |
def inference(self, data):
|
| 55 |
# Preprocess input
|
| 56 |
control_image = self.preprocess(data)
|
| 57 |
+
# Clear GPU memory
|
| 58 |
+
torch.cuda.empty_cache()
|
| 59 |
# Generate output
|
| 60 |
output_image = self.pipe(
|
| 61 |
prompt=data.get("prompt", ""),
|
| 62 |
control_image=control_image,
|
| 63 |
+
controlnet_conditioning_scale=0.5, # Weight of the ControlNet conditioning (affects guidance strength, not memory use)
|
| 64 |
+
num_inference_steps=15, # Reduced steps
|
| 65 |
height=control_image.size[1],
|
| 66 |
width=control_image.size[0],
|
| 67 |
).images[0]
|