huyai123 committed
Commit 41181c9 · verified · 1 Parent(s): 1269c65

Update handler.py

Files changed (1): handler.py +8 -14
handler.py CHANGED

@@ -1,7 +1,6 @@
 import os
 import torch
 from PIL import Image
-from diffusers.utils import load_image
 from diffusers import FluxControlNetModel
 from diffusers.pipelines import FluxControlNetPipeline
 from io import BytesIO
@@ -9,12 +8,14 @@ import logging
 
 class EndpointHandler:
     def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
+        # Set memory limit
+        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
+
         # Access the environment variable
         HF_TOKEN = os.getenv('HF_TOKEN')
         if not HF_TOKEN:
             raise ValueError("HF_TOKEN environment variable is not set")
 
-        # Log the token for debugging
         logging.basicConfig(level=logging.INFO)
         logging.info("Using HF_TOKEN")
 
@@ -34,42 +35,35 @@
         self.pipe.to("cuda")
         self.pipe.enable_attention_slicing("auto")
         self.pipe.enable_sequential_cpu_offload()
+        self.pipe.enable_memory_efficient_attention()
 
     def preprocess(self, data):
-        # Load image from file
         image_file = data.get("control_image", None)
         if not image_file:
             raise ValueError("Missing control_image in input.")
         image = Image.open(image_file)
-        w, h = image.size
-        return image.resize((w * 2, h * 2))  # Reduce upscale factor to save memory
+        return image.resize((512, 512))  # Resize to reduce memory usage
 
     def postprocess(self, output):
-        # Save output image to a file-like object
         buffer = BytesIO()
         output.save(buffer, format="PNG")
-        buffer.seek(0)  # Reset buffer pointer
+        buffer.seek(0)
         return buffer
 
     def inference(self, data):
-        # Preprocess input
         control_image = self.preprocess(data)
-        # Clear GPU memory
         torch.cuda.empty_cache()
-        # Generate output
         output_image = self.pipe(
             prompt=data.get("prompt", ""),
             control_image=control_image,
-            controlnet_conditioning_scale=0.5,  # Reduced to save memory
-            num_inference_steps=15,  # Reduced steps
+            controlnet_conditioning_scale=0.5,
+            num_inference_steps=10,
             height=control_image.size[1],
             width=control_image.size[0],
         ).images[0]
-        # Postprocess output
         return self.postprocess(output_image)
 
 if __name__ == "__main__":
-    # Example usage
     data = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'}
     handler = EndpointHandler()
     output = handler.inference(data)
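A note on the new allocator setting: PYTORCH_CUDA_ALLOC_CONF is read when PyTorch's CUDA caching allocator initializes, so assigning it inside __init__ only takes effect if nothing has touched CUDA yet in the process. A more defensive placement, as a sketch:

import os

# Assumption: no CUDA work has happened yet in this process. Setting the
# allocator config before importing torch guarantees it precedes CUDA init.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch  # CUDA initializes lazily and reads the variable set above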
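Separately, the added self.pipe.enable_memory_efficient_attention() is not a diffusers pipeline method and will raise AttributeError at startup; the xformers-backed variant is named enable_xformers_memory_efficient_attention(). A sketch of the corrected setup, assuming the xformers package is installed and pipe is the already-loaded FluxControlNetPipeline:

def apply_memory_savers(pipe):
    """Hypothetical helper collecting the memory optimizations this commit aims for."""
    pipe.enable_attention_slicing("auto")
    # Sequential offload streams submodules to the GPU on demand, so a prior
    # pipe.to("cuda") is unnecessary and wastes an initial full transfer.
    pipe.enable_sequential_cpu_offload()
    # Correct diffusers method name; requires xformers to be installed.
    pipe.enable_xformers_memory_efficient_attention()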
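Finally, the __main__ block computes output but never persists it. Since inference() returns a BytesIO that postprocess() has already rewound to position 0, the PNG bytes can be written straight to disk; a hypothetical local smoke test (paths and prompt are placeholders, and it assumes a CUDA GPU plus HF_TOKEN set in the environment):

from handler import EndpointHandler  # assumes the file above is saved as handler.py

handler = EndpointHandler()
buffer = handler.inference({
    "control_image": "input.png",  # placeholder: any local RGB image
    "prompt": "a sharp, highly detailed photograph",  # placeholder prompt
})

# buffer is a BytesIO already seek(0)'d by postprocess()
with open("output.png", "wb") as f:
    f.write(buffer.getvalue())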