|
import torch
|
|
from diffusers.utils import load_image
|
|
from diffusers import FluxControlNetModel
|
|
from diffusers.pipelines import FluxControlNetPipeline
|
|
from PIL import Image
|
|
import io
|
|
|
|
class CustomHandler:
    """Generate images with a FLUX.1 ControlNet pipeline from an enlarged control image.

    Flow: ``preprocess`` decodes the request's control image and enlarges it,
    ``inference`` runs the ControlNet pipeline at that size, and
    ``postprocess`` encodes the generated image as PNG in memory.

    Generation settings are class attributes so a subclass (or a single
    instance) can override them without editing the methods.
    """

    # Factor applied to both sides of the control image in ``preprocess``.
    UPSCALE_FACTOR = 4
    # diffusers' Flux pipelines need pixel sizes divisible by 16 (VAE
    # downsampling times 2x2 latent packing); other sizes are resized inside
    # the pipeline with a warning, so snap to this multiple up front.
    DIMENSION_MULTIPLE = 16
    # Sampling settings passed to the pipeline on every call.
    CONDITIONING_SCALE = 0.6
    NUM_INFERENCE_STEPS = 28
    GUIDANCE_SCALE = 3.5

    def __init__(self, model_dir, base_model="black-forest-labs/FLUX.1-dev", device="cuda"):
        """Load the ControlNet and the base FLUX pipeline and move them to *device*.

        Args:
            model_dir: Local directory or Hub id of the ControlNet weights.
            base_model: Hub id or local path of the base FLUX checkpoint.
            device: Device the assembled pipeline is moved to.
        """
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.bfloat16
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            base_model,
            controlnet=self.controlnet,
            torch_dtype=torch.bfloat16,
        )
        self.pipe.to(device)

    @staticmethod
    def _round_down(value, multiple):
        """Return the largest multiple of *multiple* that is <= *value*."""
        return value - value % multiple

    def preprocess(self, data):
        """Load the control image from *data* and enlarge it for generation.

        Args:
            data: Request mapping. ``data["control_image"]`` may be a path or
                binary file-like object accepted by ``PIL.Image.open``, or an
                already-decoded ``PIL.Image.Image``.

        Returns:
            An RGB ``PIL.Image.Image`` whose sides are the original sides times
            ``UPSCALE_FACTOR``, rounded down to multiples of
            ``DIMENSION_MULTIPLE``.

        Raises:
            ValueError: If ``control_image`` is missing or empty, or the image
                is too small to yield a non-zero target size.
        """
        image_file = data.get("control_image", None)
        if not image_file:
            raise ValueError("Missing control_image in input.")

        if isinstance(image_file, Image.Image):
            image = image_file.convert("RGB")
        else:
            # Image.open is lazy and keeps the file handle open; convert()
            # forces a full decode into a new image, so the source can be
            # closed right away.
            with Image.open(image_file) as source:
                image = source.convert("RGB")

        # RGB conversion above gives the 3-channel input the VAE expects
        # (alpha/palette/grayscale uploads would otherwise pass through as-is).
        width, height = image.size
        target_width = self._round_down(width * self.UPSCALE_FACTOR, self.DIMENSION_MULTIPLE)
        target_height = self._round_down(height * self.UPSCALE_FACTOR, self.DIMENSION_MULTIPLE)
        if target_width == 0 or target_height == 0:
            raise ValueError(
                f"control_image of size {width}x{height} is too small to upscale."
            )
        return image.resize((target_width, target_height))

    def postprocess(self, output):
        """Encode *output* as PNG into an in-memory buffer.

        Args:
            output: Image object exposing ``save(fp, format=...)``, e.g. the
                PIL image produced by the pipeline.

        Returns:
            An ``io.BytesIO`` holding the PNG bytes, rewound to position 0 so
            it can be read immediately.
        """
        buffer = io.BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    def inference(self, data):
        """Run the full request: preprocess, generate, and encode as PNG.

        Args:
            data: Request mapping with ``control_image`` (see ``preprocess``)
                and an optional ``prompt`` string; a missing or ``None``
                prompt is sent as ``""``.

        Returns:
            The ``io.BytesIO`` produced by ``postprocess``.

        Raises:
            ValueError: Propagated from ``preprocess`` for bad control images.
        """
        control_image = self.preprocess(data)
        width, height = control_image.size

        prompt = data.get("prompt")
        if prompt is None:
            # An explicit null must not reach the pipeline, which needs either
            # a prompt or precomputed prompt embeddings.
            prompt = ""

        output_image = self.pipe(
            prompt=prompt,
            control_image=control_image,
            controlnet_conditioning_scale=self.CONDITIONING_SCALE,
            num_inference_steps=self.NUM_INFERENCE_STEPS,
            guidance_scale=self.GUIDANCE_SCALE,
            height=height,
            width=width,
        ).images[0]

        return self.postprocess(output_image)
|
|
|