Spaces:
Build error
Build error
Update flux/util.py
Browse files- flux/util.py +0 -45
flux/util.py
CHANGED
|
@@ -4,7 +4,6 @@ from dataclasses import dataclass
|
|
| 4 |
import torch
|
| 5 |
from einops import rearrange
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
-
from imwatermark import WatermarkEncoder
|
| 8 |
from safetensors.torch import load_file as load_sft
|
| 9 |
|
| 10 |
from flux.model import Flux, FluxParams
|
|
@@ -155,47 +154,3 @@ def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEn
|
|
| 155 |
missing, unexpected = ae.load_state_dict(sd, strict=False)
|
| 156 |
print_load_warning(missing, unexpected)
|
| 157 |
return ae
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
class WatermarkEmbedder:
|
| 161 |
-
def __init__(self, watermark):
|
| 162 |
-
self.watermark = watermark
|
| 163 |
-
self.num_bits = len(WATERMARK_BITS)
|
| 164 |
-
self.encoder = WatermarkEncoder()
|
| 165 |
-
self.encoder.set_watermark("bits", self.watermark)
|
| 166 |
-
|
| 167 |
-
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
| 168 |
-
"""
|
| 169 |
-
Adds a predefined watermark to the input image
|
| 170 |
-
|
| 171 |
-
Args:
|
| 172 |
-
image: ([N,] B, RGB, H, W) in range [-1, 1]
|
| 173 |
-
|
| 174 |
-
Returns:
|
| 175 |
-
same as input but watermarked
|
| 176 |
-
"""
|
| 177 |
-
image = 0.5 * image + 0.5
|
| 178 |
-
squeeze = len(image.shape) == 4
|
| 179 |
-
if squeeze:
|
| 180 |
-
image = image[None, ...]
|
| 181 |
-
n = image.shape[0]
|
| 182 |
-
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
|
| 183 |
-
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
| 184 |
-
# watermarking libary expects input as cv2 BGR format
|
| 185 |
-
for k in range(image_np.shape[0]):
|
| 186 |
-
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
| 187 |
-
image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
|
| 188 |
-
image.device
|
| 189 |
-
)
|
| 190 |
-
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
| 191 |
-
if squeeze:
|
| 192 |
-
image = image[0]
|
| 193 |
-
image = 2 * image - 1
|
| 194 |
-
return image
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
# A fixed 48-bit message that was choosen at random
|
| 198 |
-
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
|
| 199 |
-
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
| 200 |
-
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
| 201 |
-
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
|
|
|
| 4 |
import torch
|
| 5 |
from einops import rearrange
|
| 6 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 7 |
from safetensors.torch import load_file as load_sft
|
| 8 |
|
| 9 |
from flux.model import Flux, FluxParams
|
|
|
|
| 154 |
missing, unexpected = ae.load_state_dict(sd, strict=False)
|
| 155 |
print_load_warning(missing, unexpected)
|
| 156 |
return ae
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|