huyai123 commited on
Commit
59d59d3
·
verified ·
1 Parent(s): a171858

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.json +21 -0
  2. diffusion_pytorch_model.safetensors +3 -0
  3. handler.py +52 -0
config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FluxControlNetModel",
3
+ "_diffusers_version": "0.31.0.dev0",
4
+ "_name_or_path": "/data/checkpoints/flux_controlnet_hf/controlnet_upscaling//",
5
+ "attention_head_dim": 128,
6
+ "axes_dims_rope": [
7
+ 16,
8
+ 56,
9
+ 56
10
+ ],
11
+ "guidance_embeds": true,
12
+ "in_channels": 64,
13
+ "joint_attention_dim": 4096,
14
+ "num_attention_heads": 24,
15
+ "num_layers": 5,
16
+ "num_mode": null,
17
+ "num_single_layers": 0,
18
+ "patch_size": 1,
19
+ "pooled_projection_dim": 768
20
+ }
21
+
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a7ea24d2037ff2aa4d25f8b4ce9fe7e739a2cfe6b9d05106788005d5058c8ca
3
+ size 3583232168
handler.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.utils import load_image
3
+ from diffusers import FluxControlNetModel
4
+ from diffusers.pipelines import FluxControlNetPipeline
5
+ from PIL import Image
6
+ import io
7
+
8
+ class CustomHandler:
9
+ def __init__(self, model_dir):
10
+ # Load model and pipeline
11
+ self.controlnet = FluxControlNetModel.from_pretrained(
12
+ model_dir, torch_dtype=torch.bfloat16
13
+ )
14
+ self.pipe = FluxControlNetPipeline.from_pretrained(
15
+ "black-forest-labs/FLUX.1-dev",
16
+ controlnet=self.controlnet,
17
+ torch_dtype=torch.bfloat16
18
+ )
19
+ self.pipe.to("cuda")
20
+
21
+ def preprocess(self, data):
22
+ # Load image from file
23
+ image_file = data.get("control_image", None)
24
+ if not image_file:
25
+ raise ValueError("Missing control_image in input.")
26
+ image = Image.open(image_file)
27
+ w, h = image.size
28
+ # Upscale x4
29
+ return image.resize((w * 4, h * 4))
30
+
31
+ def postprocess(self, output):
32
+ # Save output image to a file-like object
33
+ buffer = io.BytesIO()
34
+ output.save(buffer, format="PNG")
35
+ buffer.seek(0) # Reset buffer pointer
36
+ return buffer
37
+
38
+ def inference(self, data):
39
+ # Preprocess input
40
+ control_image = self.preprocess(data)
41
+ # Generate output
42
+ output_image = self.pipe(
43
+ prompt=data.get("prompt", ""),
44
+ control_image=control_image,
45
+ controlnet_conditioning_scale=0.6,
46
+ num_inference_steps=28,
47
+ guidance_scale=3.5,
48
+ height=control_image.size[1],
49
+ width=control_image.size[0],
50
+ ).images[0]
51
+ # Postprocess output
52
+ return self.postprocess(output_image)