|
from typing import Callable, Optional |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
def ordered_halving(val):
    """Map an integer to a fraction by mirroring its 64-digit binary form.

    ``val`` is rendered as a zero-padded, 64-character binary string, the
    character order is reversed, and the reversed string is parsed back as a
    base-2 integer and divided by 2**64.

    Examples: 0 -> 0.0, 1 -> 0.5, 2 -> 0.25, 3 -> 0.75.
    """
    bits = format(val, "064b")
    # Reversing the digits turns the least significant bit into the most
    # significant one.
    mirrored = int("".join(reversed(bits)), 2)
    return mirrored / 2**64
|
|
|
|
|
|
|
def uniform(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield lists of frame indices forming evenly spaced context windows.

    When ``num_frames`` does not exceed ``context_size`` a single window
    holding every frame index is yielded.  Otherwise, for each power-of-two
    frame spacing (1, 2, 4, ...) up to the effective stride count, windows of
    ``context_size`` indices are produced, wrapped modulo ``num_frames``.
    The start offset of the windows depends on ``ordered_halving(step)``.

    Args:
        step: Index of the current step; drives the window offset.
        num_steps: Accepted for signature parity; not used here.
        num_frames: Total number of frames to cover.
        context_size: Number of indices per window.  Despite the ``None``
            default it is compared and multiplied as a number.
        context_stride: Upper bound on how many spacing levels are used.
        context_overlap: Subtracted from the distance between consecutive
            window starts.
        closed_loop: When false, the range of window starts is shortened by
            ``context_overlap``.

    Yields:
        list[int]: Frame indices of one context window.
    """
    # Everything fits into one window: emit it and stop.
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    # Cap the number of spacing levels by how often the window span can be
    # doubled relative to the frame count.
    level_cap = int(np.ceil(np.log2(num_frames / context_size))) + 1
    level_count = min(context_stride, level_cap)

    for level in range(level_count):
        spacing = 1 << level
        fraction = ordered_halving(step)
        shift = int(round(num_frames * fraction))

        first_start = int(fraction * spacing) + shift
        last_start = num_frames + shift - (0 if closed_loop else context_overlap)
        advance = context_size * spacing - context_overlap

        for start in range(first_start, last_start, advance):
            window = range(start, start + context_size * spacing, spacing)
            yield [index % num_frames for index in window]
|
|
|
|
|
def shuffle(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield randomly ordered frame indices in chunks of ``context_size``.

    The indices ``0 .. num_frames - 1`` are put in random order.  If their
    count is not a multiple of ``context_size``, indices from the front of
    that random order are appended as padding, and the padded list is put in
    random order once more before being cut into consecutive chunks.

    Only ``num_frames`` and ``context_size`` are used; the remaining
    parameters are accepted for signature parity.

    Yields:
        list[int]: One chunk of frame indices.
    """
    import random

    frames = [*range(num_frames)]
    pool = random.sample(frames, k=len(frames))

    # Pad with already-drawn indices so the pool can be cut into full chunks.
    leftover = len(pool) % context_size
    if leftover:
        pool.extend(pool[: context_size - leftover])

    # Second pass so the padding is not stuck at the end.
    pool = random.sample(pool, k=len(pool))

    for start in range(0, len(pool), context_size):
        yield pool[start : start + context_size]
|
|
|
|
|
def composite(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Delegate to ``shuffle`` early on and to ``uniform`` afterwards.

    While ``step / num_steps`` is below 0.1 the ``shuffle`` scheduler is
    used; from then on ``uniform`` is used.  All arguments are forwarded
    unchanged, and the chosen scheduler's result is returned as-is.
    """
    progress = step / num_steps
    scheduler = shuffle if progress < 0.1 else uniform
    return scheduler(
        step,
        num_steps,
        num_frames,
        context_size,
        context_stride,
        context_overlap,
        closed_loop,
    )
|
|
|
|
|
def get_context_scheduler(name: str) -> Callable:
    """Return the scheduler function registered under ``name``.

    Recognised names are ``"uniform"``, ``"shuffle"`` and ``"composite"``.

    Raises:
        ValueError: If ``name`` is not one of the recognised names.
    """
    if name == "uniform":
        return uniform
    if name == "shuffle":
        return shuffle
    if name == "composite":
        return composite
    raise ValueError(f"Unknown context_overlap policy {name}")
|
|
|
|
|
def get_total_steps(
    scheduler,
    timesteps: list[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` produces over all timesteps.

    The scheduler is invoked once per position in ``timesteps`` (with the
    position, not the timestep value, as its ``step`` argument) and the
    number of windows it yields is summed.

    Args:
        scheduler: Callable taking ``(step, num_steps, num_frames,
            context_size, context_stride, context_overlap, closed_loop)``
            and returning an iterable of windows.
        timesteps: Sequence whose length determines how many steps are run.
        num_steps: Forwarded to ``scheduler``.
        num_frames: Forwarded to ``scheduler``.
        context_size: Forwarded to ``scheduler``.
        context_stride: Forwarded to ``scheduler``.
        context_overlap: Forwarded to ``scheduler``.
        closed_loop: Forwarded to ``scheduler``.

    Returns:
        int: Total number of windows across every step.
    """
    return sum(
        # Count lazily; the windows themselves are not needed.
        sum(
            1
            for _ in scheduler(
                step_index,
                num_steps,
                num_frames,
                context_size,
                context_stride,
                context_overlap,
                # Previously dropped, so the count always reflected the
                # scheduler's closed_loop default instead of the caller's.
                closed_loop,
            )
        )
        for step_index in range(len(timesteps))
    )
|
|