Update handler.py
Browse files- handler.py +8 -14
handler.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
from PIL import Image
|
4 |
-
from diffusers.utils import load_image
|
5 |
from diffusers import FluxControlNetModel
|
6 |
from diffusers.pipelines import FluxControlNetPipeline
|
7 |
from io import BytesIO
|
@@ -9,12 +8,14 @@ import logging
|
|
9 |
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
|
|
|
|
|
|
|
12 |
# Access the environment variable
|
13 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
14 |
if not HF_TOKEN:
|
15 |
raise ValueError("HF_TOKEN environment variable is not set")
|
16 |
|
17 |
-
# Log the token for debugging
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logging.info("Using HF_TOKEN")
|
20 |
|
@@ -34,42 +35,35 @@ class EndpointHandler:
|
|
34 |
self.pipe.to("cuda")
|
35 |
self.pipe.enable_attention_slicing("auto")
|
36 |
self.pipe.enable_sequential_cpu_offload()
|
|
|
37 |
|
38 |
def preprocess(self, data):
|
39 |
-
# Load image from file
|
40 |
image_file = data.get("control_image", None)
|
41 |
if not image_file:
|
42 |
raise ValueError("Missing control_image in input.")
|
43 |
image = Image.open(image_file)
|
44 |
-
|
45 |
-
return image.resize((w * 2, h * 2)) # Reduce upscale factor to save memory
|
46 |
|
47 |
def postprocess(self, output):
|
48 |
-
# Save output image to a file-like object
|
49 |
buffer = BytesIO()
|
50 |
output.save(buffer, format="PNG")
|
51 |
-
buffer.seek(0)
|
52 |
return buffer
|
53 |
|
54 |
def inference(self, data):
|
55 |
-
# Preprocess input
|
56 |
control_image = self.preprocess(data)
|
57 |
-
# Clear GPU memory
|
58 |
torch.cuda.empty_cache()
|
59 |
-
# Generate output
|
60 |
output_image = self.pipe(
|
61 |
prompt=data.get("prompt", ""),
|
62 |
control_image=control_image,
|
63 |
-
controlnet_conditioning_scale=0.5,
|
64 |
-
num_inference_steps=
|
65 |
height=control_image.size[1],
|
66 |
width=control_image.size[0],
|
67 |
).images[0]
|
68 |
-
# Postprocess output
|
69 |
return self.postprocess(output_image)
|
70 |
|
71 |
if __name__ == "__main__":
|
72 |
-
# Example usage
|
73 |
data = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'}
|
74 |
handler = EndpointHandler()
|
75 |
output = handler.inference(data)
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
from PIL import Image
|
|
|
4 |
from diffusers import FluxControlNetModel
|
5 |
from diffusers.pipelines import FluxControlNetPipeline
|
6 |
from io import BytesIO
|
|
|
8 |
|
9 |
class EndpointHandler:
|
10 |
def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
|
11 |
+
# Set memory limit
|
12 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
|
13 |
+
|
14 |
# Access the environment variable
|
15 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
16 |
if not HF_TOKEN:
|
17 |
raise ValueError("HF_TOKEN environment variable is not set")
|
18 |
|
|
|
19 |
logging.basicConfig(level=logging.INFO)
|
20 |
logging.info("Using HF_TOKEN")
|
21 |
|
|
|
35 |
self.pipe.to("cuda")
|
36 |
self.pipe.enable_attention_slicing("auto")
|
37 |
self.pipe.enable_sequential_cpu_offload()
|
38 |
+
self.pipe.enable_memory_efficient_attention()
|
39 |
|
40 |
def preprocess(self, data):
|
|
|
41 |
image_file = data.get("control_image", None)
|
42 |
if not image_file:
|
43 |
raise ValueError("Missing control_image in input.")
|
44 |
image = Image.open(image_file)
|
45 |
+
return image.resize((512, 512)) # Resize to reduce memory usage
|
|
|
46 |
|
47 |
def postprocess(self, output):
|
|
|
48 |
buffer = BytesIO()
|
49 |
output.save(buffer, format="PNG")
|
50 |
+
buffer.seek(0)
|
51 |
return buffer
|
52 |
|
53 |
def inference(self, data):
|
|
|
54 |
control_image = self.preprocess(data)
|
|
|
55 |
torch.cuda.empty_cache()
|
|
|
56 |
output_image = self.pipe(
|
57 |
prompt=data.get("prompt", ""),
|
58 |
control_image=control_image,
|
59 |
+
controlnet_conditioning_scale=0.5,
|
60 |
+
num_inference_steps=10,
|
61 |
height=control_image.size[1],
|
62 |
width=control_image.size[0],
|
63 |
).images[0]
|
|
|
64 |
return self.postprocess(output_image)
|
65 |
|
66 |
if __name__ == "__main__":
|
|
|
67 |
data = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'}
|
68 |
handler = EndpointHandler()
|
69 |
output = handler.inference(data)
|