#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# This file is adapted from the [ESRGAN](https://github.com/xinntao/ESRGAN) project,
# via the forks by [joeyballentine](https://github.com/joeyballentine/ESRGAN) and [BlueAmulet](https://github.com/BlueAmulet/ESRGAN).
import gc
from typing import Optional

import numpy as np
import torch


def bgr_to_rgb(image: torch.Tensor) -> torch.Tensor:
# flip image channels
# https://github.com/pytorch/pytorch/issues/229
out: torch.Tensor = image.flip(-3)
# out: torch.Tensor = image[[2, 1, 0], :, :] #RGB to BGR #may be faster
return out
def rgb_to_bgr(image: torch.Tensor) -> torch.Tensor:
# same operation as bgr_to_rgb(), flip image channels
return bgr_to_rgb(image)
def bgra_to_rgba(image: torch.Tensor) -> torch.Tensor:
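    # swap the B and R channels while leaving the alpha channel in place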
out: torch.Tensor = image[[2, 1, 0, 3], :, :]
return out
def rgba_to_bgra(image: torch.Tensor) -> torch.Tensor:
# same operation as bgra_to_rgba(), flip image channels
return bgra_to_rgba(image)
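

# Usage sketch for the channel-swap helpers (assumes an OpenCV-style image converted
# to a CHW tensor, e.g. torch.from_numpy(img).permute(2, 0, 1)):
#   bgr = torch.rand(3, 64, 64)   # dummy 3-channel image, channels in BGR order
#   rgb = bgr_to_rgb(bgr)         # channel order reversed along dim -3
#   assert torch.equal(rgb_to_bgr(rgb), bgr)
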
def auto_split_upscale(
lr_img: np.ndarray,
upscale_function,
scale: int = 4,
overlap: int = 32,
# A heuristic to proactively split tiles that are too large, avoiding a CUDA error.
# The default (2048*2048) is a conservative value for moderate VRAM (e.g., 8-12GB).
# Adjust this based on your GPU and model's memory footprint.
max_tile_pixels: int = 4194304, # Default: 2048 * 2048 pixels
# Internal parameters for recursion state. Do not set these manually.
    known_max_depth: Optional[int] = None,
current_depth: int = 1,
current_tile: int = 1, # Tracks the current tile being processed
total_tiles: int = 1, # Total number of tiles at this depth level
):
    """
    Automatically splits an image into tiles for upscaling to avoid CUDA out-of-memory errors.
    It uses a combination of a pixel-count heuristic and reactive error handling to find the
    optimal processing depth, then applies this depth to all subsequent tiles.
    """
    # --- Step 0: Handle CPU-only environment ---
    # The entire splitting logic is designed to overcome GPU VRAM limitations.
    # If no CUDA-enabled GPU is present, this logic is unnecessary and adds overhead,
    # so we process the image in one go on the CPU.
    if not torch.cuda.is_available():
        # Note: this assumes the image fits into system RAM, which is usually the case.
        result, _ = upscale_function(lr_img, scale)
        # The conceptual depth is 1 since no splitting was performed.
        return result, 1

input_h, input_w, input_c = lr_img.shape
# --- Step 1: Decide if we should ATTEMPT to upscale or MUST split ---
# We must split if:
# A) The tile is too large based on our heuristic, and we don't have a known working depth yet.
# B) We have a known working depth from a sibling tile, but we haven't recursed deep enough to reach it yet.
must_split = (known_max_depth is None and (input_h * input_w) > max_tile_pixels) or \
(known_max_depth is not None and current_depth < known_max_depth)
if not must_split:
# If we are not forced to split, let's try to upscale the current tile.
try:
print(f"auto_split_upscale depth: {current_depth}", end=" ", flush=True)
result, _ = upscale_function(lr_img, scale)
# SUCCESS! The upscale worked at this depth.
print(f"progress: {current_tile}/{total_tiles}")
# Return the result and the current depth, which is now the "known_max_depth".
return result, current_depth
        except RuntimeError as e:
            # Check whether this is actually a CUDA out-of-memory error.
            if "CUDA" in str(e):
                # OOM error: the heuristic was too optimistic, and this depth is not viable.
                print("RuntimeError: CUDA out of memory...")
                # Clean up VRAM and proceed to the splitting logic below.
                torch.cuda.empty_cache()
                gc.collect()
            else:
                # A different runtime error occurred, so do not suppress it.
                raise
# If an OOM error occurred, flow continues to the splitting section.
# --- Step 2: If we reached here, we MUST split the image ---
# Safety break to prevent infinite recursion if something goes wrong.
if current_depth > 10:
raise RuntimeError("Maximum recursion depth exceeded. Check max_tile_pixels or model requirements.")
# Prepare parameters for the next level of recursion.
next_depth = current_depth + 1
new_total_tiles = total_tiles * 4
base_tile_for_next_level = (current_tile - 1) * 4
# Announce the split only when it's happening.
print(f"Splitting tile at depth {current_depth} into 4 tiles for depth {next_depth}.")
# Split the image into 4 quadrants with overlap.
top_left = lr_img[: input_h // 2 + overlap, : input_w // 2 + overlap, :]
top_right = lr_img[: input_h // 2 + overlap, input_w // 2 - overlap :, :]
bottom_left = lr_img[input_h // 2 - overlap :, : input_w // 2 + overlap, :]
bottom_right = lr_img[input_h // 2 - overlap :, input_w // 2 - overlap :, :]
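    # Each quadrant extends `overlap` pixels past the midline so the upscaler sees
    # context across the seam; the duplicated strip (overlap * scale pixels after
    # upscaling) is cropped away again when the quadrants are stitched in Step 3.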
    # Recursively process each quadrant. The first quadrant (top_left) "discovers" the
    # correct processing depth; pass the current `known_max_depth` down to it.
top_left_rlt, discovered_depth = auto_split_upscale(
top_left, upscale_function, scale=scale, overlap=overlap,
max_tile_pixels=max_tile_pixels,
known_max_depth=known_max_depth,
current_depth=next_depth,
current_tile=base_tile_for_next_level + 1,
total_tiles=new_total_tiles,
)
# Once the depth is discovered, pass it to the other quadrants to avoid redundant checks.
top_right_rlt, _ = auto_split_upscale(
top_right, upscale_function, scale=scale, overlap=overlap,
max_tile_pixels=max_tile_pixels,
known_max_depth=discovered_depth,
current_depth=next_depth,
current_tile=base_tile_for_next_level + 2,
total_tiles=new_total_tiles,
)
bottom_left_rlt, _ = auto_split_upscale(
bottom_left, upscale_function, scale=scale, overlap=overlap,
max_tile_pixels=max_tile_pixels,
known_max_depth=discovered_depth,
current_depth=next_depth,
current_tile=base_tile_for_next_level + 3,
total_tiles=new_total_tiles,
)
bottom_right_rlt, _ = auto_split_upscale(
bottom_right, upscale_function, scale=scale, overlap=overlap,
max_tile_pixels=max_tile_pixels,
known_max_depth=discovered_depth,
current_depth=next_depth,
current_tile=base_tile_for_next_level + 4,
total_tiles=new_total_tiles,
)
# --- Step 3: Stitch the results back together ---
# Reassemble the upscaled quadrants into a single image.
out_h = input_h * scale
out_w = input_w * scale
    # Create an empty output image, matching the dtype of the upscaled quadrants.
    output_img = np.zeros((out_h, out_w, input_c), dtype=top_left_rlt.dtype)
# Fill the output image, removing the overlap regions to prevent artifacts
output_img[: out_h // 2, : out_w // 2, :] = top_left_rlt[: out_h // 2, : out_w // 2, :]
output_img[: out_h // 2, -out_w // 2 :, :] = top_right_rlt[: out_h // 2, -out_w // 2 :, :]
output_img[-out_h // 2 :, : out_w // 2, :] = bottom_left_rlt[-out_h // 2 :, : out_w // 2, :]
output_img[-out_h // 2 :, -out_w // 2 :, :] = bottom_right_rlt[-out_h // 2 :, -out_w // 2 :, :]
return output_img, discovered_depth
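

if __name__ == "__main__":
    # Minimal usage sketch: `_dummy_upscale` is a hypothetical stand-in for a real
    # ESRGAN inference call, which is assumed to return (upscaled image, anything).
    def _dummy_upscale(img: np.ndarray, scale: int):
        # Nearest-neighbour enlargement in place of actual network inference.
        out = np.repeat(np.repeat(img, scale, axis=0), scale, axis=1)
        return out, None

    demo_input = np.zeros((512, 512, 3), dtype=np.uint8)
    demo_output, used_depth = auto_split_upscale(demo_input, _dummy_upscale, scale=4)
    print(f"output shape: {demo_output.shape}, split depth used: {used_depth}")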