import logging
import os
from io import BytesIO

import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from diffusers.utils import load_image


class EndpointHandler:
    def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
        # Read the Hugging Face access token from the environment
        HF_TOKEN = os.getenv("HF_TOKEN")
        if not HF_TOKEN:
            raise ValueError("HF_TOKEN environment variable is not set")

        # Configure logging (confirm a token was found, but never log the token itself)
        logging.basicConfig(level=logging.INFO)
        logging.info("HF_TOKEN found; loading models")

        # Free any cached GPU memory before loading the models
        torch.cuda.empty_cache()

        # Load the ControlNet and the Flux pipeline.
        # `token` replaces the deprecated `use_auth_token` argument.
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.float16, token=HF_TOKEN
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            token=HF_TOKEN,
        )

        # Memory savers. Sequential CPU offload manages device placement
        # itself, so the pipeline must NOT also be moved with `.to("cuda")`.
        self.pipe.enable_attention_slicing("auto")
        self.pipe.enable_sequential_cpu_offload()

    def preprocess(self, data):
        # `load_image` accepts a local path, a URL, or a PIL image,
        # and returns an RGB PIL image
        image_file = data.get("control_image")
        if not image_file:
            raise ValueError("Missing control_image in input.")
        image = load_image(image_file)

        # Upscale 2x; a modest factor keeps memory usage manageable.
        # Round the target size down to a multiple of 16, since Flux packs
        # latents in 16-pixel blocks and resizes mismatched dimensions anyway.
        w, h = image.size
        new_w = (w * 2) // 16 * 16
        new_h = (h * 2) // 16 * 16
        return image.resize((new_w, new_h))

    def postprocess(self, output):
        # Serialize the output image into an in-memory PNG buffer
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)  # Rewind so callers can read from the start
        return buffer

    def inference(self, data):
        # Preprocess input
        control_image = self.preprocess(data)

        # Free cached GPU memory before the generation pass
        torch.cuda.empty_cache()

        # Generate the upscaled image
        output_image = self.pipe(
            prompt=data.get("prompt", ""),
            control_image=control_image,
            controlnet_conditioning_scale=0.5,  # Lowered to save memory
            num_inference_steps=15,  # Fewer steps to save memory and time
            height=control_image.size[1],
            width=control_image.size[0],
        ).images[0]

        # Postprocess output
        return self.postprocess(output_image)


if __name__ == "__main__":
    # Example usage: upscale a local image and write the result to disk
    data = {"control_image": "path/to/your/image.png", "prompt": "Your prompt here"}
    handler = EndpointHandler()
    output = handler.inference(data)
    with open("upscaled.png", "wb") as f:
        f.write(output.getvalue())
    print("Saved upscaled image to upscaled.png")
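
# ---------------------------------------------------------------------------
# Optional variant (a sketch, not part of the handler above): if the endpoint
# has enough VRAM, diffusers' `enable_model_cpu_offload()` is usually much
# faster than sequential offload, because it moves whole sub-models to the
# GPU while they run instead of individual layers. The speed/memory trade-off
# stated here is an assumption about typical deployments, not a measurement.
#
#     self.pipe.enable_attention_slicing("auto")
#     self.pipe.enable_model_cpu_offload()  # instead of enable_sequential_cpu_offload()
#
# Whichever offload helper is used, it manages device placement itself, so it
# must not be combined with an explicit `self.pipe.to("cuda")`.
# ---------------------------------------------------------------------------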