import os
import torch
from PIL import Image
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from io import BytesIO
import logging

class EndpointHandler:
    def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
        # Access the environment variable
        HF_TOKEN = os.getenv('HF_TOKEN')
        if not HF_TOKEN:
            raise ValueError("HF_TOKEN environment variable is not set")

        # Basic logging setup; never log the token value itself
        logging.basicConfig(level=logging.INFO)
        logging.info("HF_TOKEN found; loading models")

        # Load the upscaler ControlNet and attach it to the FLUX.1-dev base pipeline
        # (use `token=` instead of the deprecated `use_auth_token=`)
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.bfloat16, token=HF_TOKEN
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=torch.bfloat16,
            token=HF_TOKEN,
        )
        self.pipe.to("cuda")

    def preprocess(self, data):
        # Load the control image from a path or file-like object
        image_file = data.get("control_image", None)
        if not image_file:
            raise ValueError("Missing control_image in input.")
        image = Image.open(image_file).convert("RGB")  # drop alpha/palette modes
        w, h = image.size
        # Upscale x4 so the ControlNet conditions on the target resolution
        return image.resize((w * 4, h * 4))

    def postprocess(self, output):
        # Save output image to a file-like object
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)  # Reset buffer pointer
        return buffer

    def inference(self, data):
        # Preprocess input
        control_image = self.preprocess(data)
        # Generate output
        output_image = self.pipe(
            prompt=data.get("prompt", ""),
            control_image=control_image,
            controlnet_conditioning_scale=0.6,
            num_inference_steps=28,
            height=control_image.size[1],
            width=control_image.size[0],
        ).images[0]
        # Postprocess output
        return self.postprocess(output_image)

if __name__ == "__main__":
    # Example usage (the input path and output filename are placeholders)
    data = {"control_image": "path/to/your/image.png", "prompt": "Your prompt here"}
    handler = EndpointHandler()
    output = handler.inference(data)
    # inference() returns a BytesIO buffer containing the upscaled PNG
    with open("upscaled.png", "wb") as f:
        f.write(output.getvalue())
    print("Saved upscaled image to upscaled.png")
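
    # Assumption (not part of the original handler): when deployed behind an HTTP
    # endpoint, the image typically arrives as base64-encoded bytes rather than a
    # local path. Because Image.open() also accepts file-like objects, the existing
    # preprocess() works unchanged if the request is adapted as sketched below;
    # the paths and field names mirror the placeholder example above.
    import base64

    with open("path/to/your/image.png", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    request = {
        "control_image": BytesIO(base64.b64decode(encoded)),
        "prompt": "Your prompt here",
    }
    upscaled = handler.inference(request)
    with open("upscaled_from_base64.png", "wb") as f:
        f.write(upscaled.getvalue())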