Update app.py
app.py CHANGED
@@ -120,6 +120,7 @@ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
 # Load model directly
 from transformers import AutoModelForImageSegmentation
 rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
+rmbg = rmbg.to(device=device, dtype=torch.float32)  # Keep this as float32
 
 model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
 model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
@@ -128,6 +129,7 @@ model.eval()
 
 # Change UNet
 
+
 with torch.no_grad():
     new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
     new_conv_in.weight.zero_()
@@ -314,7 +316,7 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
 
     return c, uc
 
-
+@spaces.GPU(duration=60)
 @torch.inference_mode()
 def pytorch2numpy(imgs, quant=True):
     results = []
@@ -331,7 +333,7 @@ def pytorch2numpy(imgs, quant=True):
         results.append(y)
     return results
 
-
+@spaces.GPU(duration=60)
 @torch.inference_mode()
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0  # so that 127 must be strictly 0.0
@@ -359,7 +361,7 @@ def resize_without_crop(image, target_width, target_height):
     resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
     return np.array(resized_image)
 
-
+@spaces.GPU(duration=60)
 @torch.inference_mode()
 def run_rmbg(img, sigma=0.0):
     # Convert RGBA to RGB if needed
@@ -384,6 +386,8 @@ def run_rmbg(img, sigma=0.0):
     rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
     result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
     return result.clip(0, 255).astype(np.uint8), rgba
+
+@spaces.GPU(duration=60)
 @torch.inference_mode()
 def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
     clear_memory()
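For reference, the background-removal model added in the first hunk follows the standard transformers remote-code loading path and is kept in float32. A minimal, self-contained sketch of that pattern (the device fallback below is an assumption; app.py defines its own device earlier in the file):

import torch
from transformers import AutoModelForImageSegmentation

# Assumed device selection; the real app.py defines `device` elsewhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

# RMBG-1.4 ships custom modeling code, so trust_remote_code=True is required.
rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
rmbg = rmbg.to(device=device, dtype=torch.float32)  # kept in float32, as in the diff
rmbg.eval()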
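The @spaces.GPU(duration=60) decorator added throughout comes from Hugging Face's spaces package (ZeroGPU): it requests a GPU for each call to the decorated function, with duration as the expected upper bound in seconds. A minimal usage sketch, assuming import spaces is present near the top of app.py (the function below is a hypothetical placeholder, not one of the Space's real functions):

import spaces
import torch

@spaces.GPU(duration=60)  # request a ZeroGPU slot for up to ~60 s per call
@torch.inference_mode()
def demo_infer(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder body; the real Space runs its diffusion and rmbg code here.
    return x * 2

Outside a ZeroGPU Space the decorator is documented to be a no-op, so the same functions still run on regular hardware.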