import os
import logging
from io import BytesIO

import torch
from PIL import Image
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline


class EndpointHandler:
    def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
        # Read the Hugging Face access token from the environment
        HF_TOKEN = os.getenv("HF_TOKEN")
        if not HF_TOKEN:
            raise ValueError("HF_TOKEN environment variable is not set")

        # Configure logging; never log the token value itself
        logging.basicConfig(level=logging.INFO)
        logging.info("HF_TOKEN found; loading models")

        # Clear GPU memory before loading the models
        torch.cuda.empty_cache()

        # Load the ControlNet upscaler and the FLUX.1-dev base pipeline
        # (`token` replaces the deprecated `use_auth_token` argument)
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.float16, token=HF_TOKEN
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            token=HF_TOKEN,
        )

        # Memory-saving options. Sequential CPU offload manages device
        # placement itself, so the pipeline is not moved to CUDA explicitly.
        self.pipe.enable_attention_slicing("auto")
        self.pipe.enable_sequential_cpu_offload()

    def preprocess(self, data):
        # Load the control image from the request payload
        image_file = data.get("control_image", None)
        if not image_file:
            raise ValueError("Missing control_image in input.")
        image = Image.open(image_file).convert("RGB")
        w, h = image.size
        # 2x upscale keeps memory usage manageable
        return image.resize((w * 2, h * 2))

    def postprocess(self, output):
        # Save the output image to an in-memory PNG buffer
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)  # Reset buffer pointer
        return buffer

    def inference(self, data):
        # Preprocess input
        control_image = self.preprocess(data)

        # Clear GPU memory before generation
        torch.cuda.empty_cache()

        # Generate the upscaled image
        output_image = self.pipe(
            prompt=data.get("prompt", ""),
            control_image=control_image,
            controlnet_conditioning_scale=0.5,  # Reduced to save memory
            num_inference_steps=15,  # Reduced steps
            height=control_image.size[1],
            width=control_image.size[0],
        ).images[0]

        # Postprocess output
        return self.postprocess(output_image)
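
    # Hugging Face Inference Endpoints invoke custom handlers via __call__(data).
    # Exposing inference() under that name is an assumption about how this
    # handler is served; remove this alias if the handler is only called directly.
    def __call__(self, data):
        return self.inference(data)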


if __name__ == "__main__":
    # Example usage: run the handler locally on an image file
    data = {"control_image": "path/to/your/image.png", "prompt": "Your prompt here"}
    handler = EndpointHandler()
    output = handler.inference(data)
    # `output` is a BytesIO holding PNG bytes; write it to disk for inspection
    with open("upscaled.png", "wb") as f:
        f.write(output.getvalue())