File size: 2,384 Bytes
ea54d45 b636aa5 ea54d45 b636aa5 906db1e f4b717d b636aa5 906db1e 7754b09 ea54d45 fd8d501 f4b717d fd8d501 ea54d45 b636aa5 7754b09 fd8d501 b636aa5 7754b09 b636aa5 ea54d45 fd8d501 b636aa5 7754b09 b636aa5 906db1e b636aa5 e728996 b636aa5 e728996 b636aa5 e728996 f4b717d fd8d501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import os
import torch
from PIL import Image
from diffusers.utils import load_image
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from io import BytesIO
import logging
class EndpointHandler:
    """Upscale an image with the FLUX.1-dev ControlNet upscaler.

    A ``FluxControlNetModel`` (default ``huyai123/Flux.1-dev-Image-Upscaler``)
    is loaded and wrapped in a ``FluxControlNetPipeline`` built on
    ``black-forest-labs/FLUX.1-dev``. Both downloads authenticate with the
    token found in the ``HF_TOKEN`` environment variable.
    """

    DEFAULT_MODEL_DIR = "huyai123/Flux.1-dev-Image-Upscaler"
    BASE_MODEL = "black-forest-labs/FLUX.1-dev"
    # Factor applied to each side of the input before it is used as the
    # ControlNet condition.
    UPSCALE_FACTOR = 4
    # Defaults used when the request payload does not override them.
    CONDITIONING_SCALE = 0.6
    NUM_INFERENCE_STEPS = 28
    # FLUX needs pixel height/width divisible by 16 (VAE downsampling x8,
    # then 2x2 latent packing).
    SIZE_MULTIPLE = 16

    _logger = logging.getLogger(__name__)

    def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler", device: str = "cuda"):
        """Load the ControlNet and the FLUX pipeline and move them to *device*.

        Args:
            model_dir: Hub id or local directory of the ControlNet weights.
                An empty value falls back to ``DEFAULT_MODEL_DIR``.
            device: Torch device the pipeline is moved to (default ``"cuda"``).

        Raises:
            ValueError: if the ``HF_TOKEN`` environment variable is not set.
        """
        hf_token = os.getenv('HF_TOKEN')
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable is not set")
        # Only record that a token was found; never log the token itself.
        self._logger.info("Using HF_TOKEN")

        model_dir = model_dir or self.DEFAULT_MODEL_DIR
        # `token` replaces the deprecated `use_auth_token` keyword.
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.bfloat16, token=hf_token
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            self.BASE_MODEL,
            controlnet=self.controlnet,
            torch_dtype=torch.bfloat16,
            token=hf_token,
        )
        self.pipe.to(device)

    @classmethod
    def _snap(cls, size):
        """Round *size* down to a multiple of SIZE_MULTIPLE, never below SIZE_MULTIPLE."""
        return max(cls.SIZE_MULTIPLE, size - size % cls.SIZE_MULTIPLE)

    @staticmethod
    def _open_image(source):
        """Return an RGB PIL image from a path, URL, bytes, file object or PIL image."""
        if isinstance(source, str) and source.startswith(("http://", "https://")):
            return load_image(source)
        if isinstance(source, (bytes, bytearray)):
            source = BytesIO(source)
        if not isinstance(source, Image.Image):
            source = Image.open(source)
        # load_image applies EXIF orientation and converts to RGB, so
        # RGBA / palette / grayscale inputs reach the pipeline as 3 channels.
        return load_image(source)

    def preprocess(self, data):
        """Load ``data["control_image"]`` and upscale it for use as the condition.

        Returns:
            An RGB PIL image UPSCALE_FACTOR times larger than the input, with
            each side snapped down to a multiple of SIZE_MULTIPLE.

        Raises:
            ValueError: if ``control_image`` is missing or empty.
        """
        image_file = data.get("control_image", None)
        if not image_file:
            raise ValueError("Missing control_image in input.")
        image = self._open_image(image_file)
        w, h = image.size
        target = (
            self._snap(w * self.UPSCALE_FACTOR),
            self._snap(h * self.UPSCALE_FACTOR),
        )
        return image.resize(target)

    def postprocess(self, output):
        """Encode *output* (a PIL image) as PNG into a rewound ``BytesIO``."""
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)  # Let callers read from the start.
        return buffer

    def inference(self, data):
        """Run the upscaler on *data* and return the result as a PNG buffer.

        Recognised keys: ``control_image`` (required), ``prompt``,
        ``controlnet_conditioning_scale`` and ``num_inference_steps``.
        """
        control_image = self.preprocess(data)
        width, height = control_image.size
        output_image = self.pipe(
            prompt=data.get("prompt") or "",
            control_image=control_image,
            controlnet_conditioning_scale=float(
                data.get("controlnet_conditioning_scale", self.CONDITIONING_SCALE)
            ),
            num_inference_steps=int(
                data.get("num_inference_steps", self.NUM_INFERENCE_STEPS)
            ),
            height=height,
            width=width,
        ).images[0]
        return self.postprocess(output_image)

    def __call__(self, data):
        """Entry point invoked by Hugging Face Inference Endpoints; same as ``inference``."""
        return self.inference(data)
if __name__ == "__main__":
    # Example usage: upscale a local image and save the result as a PNG file.
    logging.basicConfig(level=logging.INFO)
    data = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'}
    handler = EndpointHandler()
    output = handler.inference(data)
    # inference() returns an in-memory PNG buffer; write it out so the
    # result can actually be inspected instead of printing the buffer repr.
    output_path = "upscaled.png"
    with open(output_path, "wb") as fh:
        fh.write(output.getvalue())
    print(f"Saved upscaled image to {output_path}")
|