File size: 2,638 Bytes
ea54d45
b636aa5
ea54d45
b636aa5
 
 
906db1e
f4b717d
b636aa5
906db1e
7754b09
ea54d45
fd8d501
 
 
f4b717d
1269c65
f4b717d
fd8d501
ea54d45
1269c65
 
 
 
7754b09
edd8452
b636aa5
7754b09
b636aa5
 
edd8452
fd8d501
b636aa5
 
1269c65
a44f1cf
 
b636aa5
 
 
 
 
 
 
1269c65
b636aa5
7754b09
b636aa5
906db1e
b636aa5
 
 
 
 
 
 
1269c65
 
b636aa5
e728996
 
 
1269c65
 
e728996
 
b636aa5
 
e728996
f4b717d
 
 
 
 
 
fd8d501
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import torch
from PIL import Image
from diffusers.utils import load_image
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from io import BytesIO
import logging

class EndpointHandler:
    """Upscale an image with a FLUX.1-dev ControlNet pipeline.

    The input image is enlarged by ``UPSCALE_FACTOR`` (rounded to a size the
    FLUX pipeline accepts) and passed as the ControlNet conditioning image;
    the pipeline then generates an image at that resolution, guided by an
    optional text prompt.
    """

    # Factor applied to the input size before generation (kept at 2 to save memory).
    UPSCALE_FACTOR = 2
    # Diffusers' FLUX pipelines expect height/width divisible by 16
    # (vae_scale_factor * 2): older releases raise, newer ones silently resize,
    # which would desynchronise the output from the control image.
    SIZE_MULTIPLE = 16

    def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler", torch_dtype=None):
        """Load the ControlNet weights and the FLUX.1-dev pipeline.

        Args:
            model_dir: Hub repo id or local path of the FLUX ControlNet weights.
            torch_dtype: Weight dtype for both models. Defaults to
                ``torch.bfloat16``; float16 can overflow in FLUX's text
                encoder and yield NaN/black images (see the diffusers Flux docs).

        Raises:
            ValueError: If the ``HF_TOKEN`` environment variable is unset or empty.
        """
        hf_token = os.getenv('HF_TOKEN')
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable is not set")

        # Only record that a token is in use -- its value is never logged.
        logging.basicConfig(level=logging.INFO)
        logging.info("Using HF_TOKEN")

        dtype = torch.bfloat16 if torch_dtype is None else torch_dtype

        # Clear cached GPU memory before loading the large models.
        torch.cuda.empty_cache()

        # `token=` is the current diffusers keyword (replaces `use_auth_token=`).
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=dtype, token=hf_token
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=dtype,
            token=hf_token,
        )
        # No `self.pipe.to("cuda")`: sequential CPU offload manages device
        # placement itself, and moving the whole pipeline to the GPU first
        # defeats the memory savings (and may OOM before offload is enabled).
        # NOTE(review): attention slicing may be a no-op for the Flux transformer -- verify.
        self.pipe.enable_attention_slicing("auto")
        self.pipe.enable_sequential_cpu_offload()

    @staticmethod
    def _target_size(width, height, scale=UPSCALE_FACTOR, multiple=SIZE_MULTIPLE):
        """Return ``(width, height)`` scaled by ``scale`` and rounded down to ``multiple``.

        Each side is at least ``multiple`` so tiny inputs still yield a valid size.
        """
        return (
            max(multiple, (width * scale) // multiple * multiple),
            max(multiple, (height * scale) // multiple * multiple),
        )

    def preprocess(self, data):
        """Load ``data["control_image"]`` and enlarge it to the generation size.

        Args:
            data: Mapping whose ``"control_image"`` entry is a file path, a
                binary file-like object, raw image bytes, or a PIL image.

        Returns:
            An RGB PIL image resized by ``UPSCALE_FACTOR`` and rounded to a
            multiple of ``SIZE_MULTIPLE`` on each side.

        Raises:
            ValueError: If ``control_image`` is missing or empty.
        """
        source = data.get("control_image", None)
        if not source:
            raise ValueError("Missing control_image in input.")
        if isinstance(source, (bytes, bytearray)):
            source = BytesIO(source)

        if isinstance(source, Image.Image):
            image = source.convert("RGB")
        else:
            # Context manager closes the underlying file; convert() returns a
            # fully loaded copy, and forcing RGB drops alpha/palette channels.
            with Image.open(source) as opened:
                image = opened.convert("RGB")

        width, height = self._target_size(*image.size)
        return image.resize((width, height))

    def postprocess(self, output):
        """Encode a PIL image as PNG into an in-memory buffer positioned at 0."""
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)  # Rewind so callers can read from the start.
        return buffer

    def inference(self, data):
        """Run the upscaling pipeline on one request.

        Args:
            data: Mapping with ``"control_image"`` (see ``preprocess``) and an
                optional ``"prompt"`` string (defaults to ``""``).

        Returns:
            A ``BytesIO`` containing the generated image as PNG.
        """
        control_image = self.preprocess(data)
        # Clear GPU memory before the denoising loop.
        torch.cuda.empty_cache()
        width, height = control_image.size
        output_image = self.pipe(
            prompt=data.get("prompt", ""),
            control_image=control_image,
            controlnet_conditioning_scale=0.5,  # Reduced to save memory
            num_inference_steps=15,  # Reduced steps
            height=height,
            width=width,
        ).images[0]
        return self.postprocess(output_image)

if __name__ == "__main__":
    # Example usage: upscale a local image and write the PNG result to disk.
    data = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'}
    handler = EndpointHandler()
    output = handler.inference(data)
    # `output` is an in-memory PNG buffer; printing it would only show its repr,
    # so persist the bytes instead.
    with open("upscaled.png", "wb") as out_file:
        out_file.write(output.getvalue())
    print("Saved upscaled image to upscaled.png")