avans06 committed
Commit aa7cfca · 1 Parent(s): 129b135

feat(dataops): Proactively split large tiles in auto_split_upscale to prevent CUDA OOM errors.
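The new heuristic works on raw pixel count: with the default `max_tile_pixels` of 4194304 (2048 × 2048), any tile larger than that is split into quadrants before the GPU is ever asked to process it, so the first CUDA OOM error is avoided rather than caught. As a rough, purely illustrative sketch (not part of this commit; `estimate_split_depth` is a hypothetical helper, and the `overlap` margin each quadrant keeps is ignored), the recursion depth for a given input size can be estimated like this:

```python
# Illustration only: estimate how deep auto_split_upscale would recurse for a
# given input size under the default max_tile_pixels heuristic (2048 * 2048),
# ignoring the extra `overlap` pixels each quadrant keeps.
def estimate_split_depth(height: int, width: int, max_tile_pixels: int = 2048 * 2048) -> int:
    depth = 1
    while height * width > max_tile_pixels:
        height, width = height // 2, width // 2  # each split halves both dimensions
        depth += 1
    return depth

# An 8000x6000 photo (48 MP) would be split twice: 48 MP -> 12 MP -> 3 MP tiles.
print(estimate_split_depth(6000, 8000))  # 3
```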

Files changed (1)
  1. utils/dataops.py +91 -29
utils/dataops.py CHANGED
@@ -37,70 +37,132 @@ def auto_split_upscale(
     upscale_function,
     scale: int = 4,
     overlap: int = 32,
-    max_depth: int = None,
+    # A heuristic to proactively split tiles that are too large, avoiding a CUDA error.
+    # The default (2048*2048) is a conservative value for moderate VRAM (e.g., 8-12GB).
+    # Adjust this based on your GPU and model's memory footprint.
+    max_tile_pixels: int = 4194304,  # Default: 2048 * 2048 pixels
+    # Internal parameters for recursion state. Do not set these manually.
+    known_max_depth: int = None,
     current_depth: int = 1,
     current_tile: int = 1,  # Tracks the current tile being processed
     total_tiles: int = 1,  # Total number of tiles at this depth level
 ):
-    # Attempt to upscale if unknown depth or if reached known max depth
-    if max_depth is None or max_depth == current_depth:
+    # --- Step 0: Handle CPU-only environment ---
+    # The entire splitting logic is designed to overcome GPU VRAM limitations.
+    # If no CUDA-enabled GPU is present, this logic is unnecessary and adds overhead.
+    # Therefore, we process the image in one go on the CPU.
+    if not torch.cuda.is_available():
+        # Note: This assumes the image fits into system RAM, which is usually the case.
+        result, _ = upscale_function(lr_img, scale)
+        # The conceptual depth is 1 since no splitting was performed.
+        return result, 1
+
+    """
+    Automatically splits an image into tiles for upscaling to avoid CUDA out-of-memory errors.
+    It uses a combination of a pixel-count heuristic and reactive error handling to find the
+    optimal processing depth, then applies this depth to all subsequent tiles.
+    """
+    input_h, input_w, input_c = lr_img.shape
+
+    # --- Step 1: Decide if we should ATTEMPT to upscale or MUST split ---
+    # We must split if:
+    # A) The tile is too large based on our heuristic, and we don't have a known working depth yet.
+    # B) We have a known working depth from a sibling tile, but we haven't recursed deep enough to reach it yet.
+    must_split = (known_max_depth is None and (input_h * input_w) > max_tile_pixels) or \
+                 (known_max_depth is not None and current_depth < known_max_depth)
+
+    if not must_split:
+        # If we are not forced to split, let's try to upscale the current tile.
         try:
             print(f"auto_split_upscale depth: {current_depth}", end=" ", flush=True)
             result, _ = upscale_function(lr_img, scale)
+            # SUCCESS! The upscale worked at this depth.
             print(f"progress: {current_tile}/{total_tiles}")
+            # Return the result and the current depth, which is now the "known_max_depth".
             return result, current_depth
         except RuntimeError as e:
             # Check to see if its actually the CUDA out of memory error
             if "CUDA" in str(e):
+                # OOM ERROR. Our heuristic was too optimistic. This depth is not viable.
                 print("RuntimeError: CUDA out of memory...")
-            # Re-raise the exception if not an OOM error
+                # Clean up VRAM and proceed to the splitting logic below.
+                torch.cuda.empty_cache()
+                gc.collect()
             else:
+                # A different runtime error occurred, so we should not suppress it.
                 raise RuntimeError(e)
-            # Collect garbage (clear VRAM)
-            torch.cuda.empty_cache()
-            gc.collect()
-
-    input_h, input_w, input_c = lr_img.shape
+        # If an OOM error occurred, flow continues to the splitting section.
+
+    # --- Step 2: If we reached here, we MUST split the image ---
+
+    # Safety break to prevent infinite recursion if something goes wrong.
+    if current_depth > 10:
+        raise RuntimeError("Maximum recursion depth exceeded. Check max_tile_pixels or model requirements.")
+
+    # Prepare parameters for the next level of recursion.
+    next_depth = current_depth + 1
+    new_total_tiles = total_tiles * 4
+    base_tile_for_next_level = (current_tile - 1) * 4
 
-    # Split the image into 4 quadrants with some overlap
+    # Announce the split only when it's happening.
+    print(f"Splitting tile at depth {current_depth} into 4 tiles for depth {next_depth}.")
+
+    # Split the image into 4 quadrants with overlap.
     top_left = lr_img[: input_h // 2 + overlap, : input_w // 2 + overlap, :]
     top_right = lr_img[: input_h // 2 + overlap, input_w // 2 - overlap :, :]
     bottom_left = lr_img[input_h // 2 - overlap :, : input_w // 2 + overlap, :]
     bottom_right = lr_img[input_h // 2 - overlap :, input_w // 2 - overlap :, :]
-    current_depth = current_depth + 1
-    current_tile = (current_tile - 1) * 4
-    total_tiles = total_tiles * 4
-
-    # Recursively upscale each quadrant and track the current tile number
-    # After we go through the top left quadrant, we know the maximum depth and no longer need to test for out-of-memory
-    top_left_rlt, depth = auto_split_upscale(
-        top_left, upscale_function, scale=scale, overlap=overlap, max_depth=max_depth,
-        current_depth=current_depth, current_tile=current_tile + 1, total_tiles=total_tiles,
+
+    # Recursively process each quadrant.
+    # Process the first quadrant to discover the safe depth.
+    # The first quadrant (top_left) will "discover" the correct processing depth.
+    # Pass the current `known_max_depth` down.
+    top_left_rlt, discovered_depth = auto_split_upscale(
+        top_left, upscale_function, scale=scale, overlap=overlap,
+        max_tile_pixels=max_tile_pixels,
+        known_max_depth=known_max_depth,
+        current_depth=next_depth,
+        current_tile=base_tile_for_next_level + 1,
+        total_tiles=new_total_tiles,
     )
+    # Once the depth is discovered, pass it to the other quadrants to avoid redundant checks.
     top_right_rlt, _ = auto_split_upscale(
-        top_right, upscale_function, scale=scale, overlap=overlap, max_depth=depth,
-        current_depth=current_depth, current_tile=current_tile + 2, total_tiles=total_tiles,
+        top_right, upscale_function, scale=scale, overlap=overlap,
+        max_tile_pixels=max_tile_pixels,
+        known_max_depth=discovered_depth,
+        current_depth=next_depth,
+        current_tile=base_tile_for_next_level + 2,
+        total_tiles=new_total_tiles,
     )
     bottom_left_rlt, _ = auto_split_upscale(
-        bottom_left, upscale_function, scale=scale, overlap=overlap, max_depth=depth,
-        current_depth=current_depth, current_tile=current_tile + 3, total_tiles=total_tiles,
+        bottom_left, upscale_function, scale=scale, overlap=overlap,
+        max_tile_pixels=max_tile_pixels,
+        known_max_depth=discovered_depth,
+        current_depth=next_depth,
+        current_tile=base_tile_for_next_level + 3,
+        total_tiles=new_total_tiles,
     )
     bottom_right_rlt, _ = auto_split_upscale(
-        bottom_right, upscale_function, scale=scale, overlap=overlap, max_depth=depth,
-        current_depth=current_depth, current_tile=current_tile + 4, total_tiles=total_tiles,
+        bottom_right, upscale_function, scale=scale, overlap=overlap,
+        max_tile_pixels=max_tile_pixels,
+        known_max_depth=discovered_depth,
+        current_depth=next_depth,
+        current_tile=base_tile_for_next_level + 4,
+        total_tiles=new_total_tiles,
     )
 
-    # Define the output image size
+    # --- Step 3: Stitch the results back together ---
+    # Reassemble the upscaled quadrants into a single image.
     out_h = input_h * scale
     out_w = input_w * scale
 
     # Create an empty output image
     output_img = np.zeros((out_h, out_w, input_c), np.uint8)
-
-    # Fill the output image with the upscaled quadrants, removing overlap regions
+
+    # Fill the output image, removing the overlap regions to prevent artifacts
     output_img[: out_h // 2, : out_w // 2, :] = top_left_rlt[: out_h // 2, : out_w // 2, :]
     output_img[: out_h // 2, -out_w // 2 :, :] = top_right_rlt[: out_h // 2, -out_w // 2 :, :]
     output_img[-out_h // 2 :, : out_w // 2, :] = bottom_left_rlt[-out_h // 2 :, : out_w // 2, :]
     output_img[-out_h // 2 :, -out_w // 2 :, :] = bottom_right_rlt[-out_h // 2 :, -out_w // 2 :, :]
 
-    return output_img, depth
+    return output_img, discovered_depth
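For context, a minimal usage sketch of the updated function, under stated assumptions: the repository root is on `sys.path`, the input is an HWC uint8 NumPy array, and `dummy_upscale` is a purely illustrative stand-in (not part of the repository) for a model-backed upscaler that takes `(image, scale)` and returns a `(result, extra)` tuple:

```python
import numpy as np

from utils.dataops import auto_split_upscale  # assumes the repo root is on sys.path

# Illustrative stand-in for a model-backed upscaler; auto_split_upscale expects
# a callable taking (image, scale) and returning (upscaled_image, anything).
def dummy_upscale(img: np.ndarray, scale: int):
    return np.repeat(np.repeat(img, scale, axis=0), scale, axis=1), None

lr_img = np.zeros((3000, 4000, 3), dtype=np.uint8)  # 12 MP, above the default tile budget

# With max_tile_pixels, this image is split proactively into quadrants instead of
# waiting for a CUDA OOM error to trigger the first split.
sr_img, used_depth = auto_split_upscale(
    lr_img, dummy_upscale, scale=4, overlap=32, max_tile_pixels=2048 * 2048
)
print(sr_img.shape, used_depth)  # (12000, 16000, 3); depth 2 on a GPU run, 1 on a CPU-only machine
```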