""" |
|
Streaming module API that should be implemented by all Streaming components, |
|
""" |
|
|
|
import abc
from contextlib import contextmanager
from dataclasses import dataclass
import itertools
import math
import typing as tp

import torch
from torch import nn
|
|
|
|
|
class Resetable(tp.Protocol):
    def reset(self) -> None:
        pass


State = tp.TypeVar("State", bound=Resetable)
|
|
|
|
|
class StreamingModule(abc.ABC, nn.Module, tp.Generic[State]):
    """Common API for streaming components.

    Each streaming component has a streaming state, a dataclass implementing
    the `Resetable` protocol. By convention, the first dim of every tensor
    stored in the state must be the batch size.

    If `self.is_streaming` is True, the component should use and update
    the state stored in `self._streaming_state`.

    To put a streaming component in streaming mode, use

        with module.streaming(batch_size):
            ...

    This will automatically discard the streaming state when exiting the context
    manager, and it automatically propagates to all streaming child modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be `StreamingModule` and
    implement it as well for it to work properly. See `StreamingSequential` below.
    """

    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State | None = None
        self._streaming_propagate: bool = True
|
|
|
    @property
    def is_streaming(self):
        return self._streaming_state is not None

    def set_streaming_propagate(self, streaming_propagate: bool):
        self._streaming_propagate = streaming_propagate
|
|
|
    def _apply_named_streaming(self, fn: tp.Any):
        def _handle_module(prefix: str, module: nn.Module, recurse: bool = True):
            propagate = True
            if isinstance(module, StreamingModule):
                if module._streaming_propagate:
                    fn(prefix, module)
                else:
                    # This module opted out of propagation: skip it and its subtree.
                    propagate = False
            if not recurse:
                return
            if propagate:
                for name, child in module.named_children():
                    _handle_module(prefix + "." + name, child)

        # Handle the root module separately so that its children do not get a
        # leading dot in their name.
        _handle_module("", self, recurse=False)
        for name, child in self.named_children():
            _handle_module(name, child)
|
|
|
    def _start_streaming(self, batch_size: int):
        def _start_streaming(name: str, module: StreamingModule):
            module._streaming_state = module._init_streaming_state(batch_size)

        self._apply_named_streaming(_start_streaming)

    def _stop_streaming(self):
        def _stop_streaming(name: str, module: StreamingModule):
            module._streaming_state = None

        self._apply_named_streaming(_stop_streaming)
|
|
|
    @abc.abstractmethod
    def _init_streaming_state(self, batch_size: int) -> State: ...

    def streaming_forever(self, batch_size: int):
        """Enter streaming mode without a context manager: the module stays
        in streaming mode and the state is never automatically discarded."""
        self._start_streaming(batch_size)
|
|
|
    @contextmanager
    def streaming(self, batch_size: int):
        """Context manager to enter streaming mode. The streaming state is
        discarded on exit."""
        self._start_streaming(batch_size)
        try:
            yield
        finally:
            self._stop_streaming()
|
|
|
    def reset_streaming(self):
        """Reset the streaming state."""

        def _reset(name: str, module: StreamingModule):
            state = module._streaming_state
            if state is None:
                raise ValueError(
                    f"Trying to reset streaming, but {name} wasn't streaming."
                )
            state.reset()

        self._apply_named_streaming(_reset)
|
|
|
    def get_streaming_state(self) -> dict[str, tp.Any]:
        """Return the complete streaming state, including that of sub-modules."""
        state: dict[str, tp.Any] = {}

        def _add(name: str, module: StreamingModule):
            state[name] = module._streaming_state

        self._apply_named_streaming(_add)
        return state
|
|
|
    def set_streaming_state(self, state: dict[str, tp.Any]):
        """Set the streaming state, including that of sub-modules."""
        state = dict(state)

        def _set(name: str, module: StreamingModule):
            if name in state:
                module._streaming_state = state[name]
                state.pop(name)
            else:
                raise RuntimeError(f"Expected to find a streaming state for {name}.")

        self._apply_named_streaming(_set)
        if state:
            raise RuntimeError(f"Some states were not consumed: {list(state.keys())}")
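

# A minimal usage sketch added for illustration (not part of the original API
# surface): `_CounterState` / `_Counter` are hypothetical modules used to show
# how `streaming`, `get_streaming_state` and `set_streaming_state` interact.
@dataclass
class _CounterState:
    offset: int = 0

    def reset(self) -> None:
        self.offset = 0


class _Counter(StreamingModule[_CounterState]):
    def _init_streaming_state(self, batch_size: int) -> _CounterState:
        return _CounterState()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._streaming_state is not None:
            self._streaming_state.offset += x.shape[-1]
        return x


def _example_state_roundtrip():
    import copy

    counter = _Counter()
    with counter.streaming(batch_size=1):
        counter(torch.zeros(1, 1, 4))
        # `get_streaming_state` returns references to the live state objects
        # (the root module is keyed by the empty string), so deep-copy before
        # using the result as a checkpoint.
        saved = copy.deepcopy(counter.get_streaming_state())
        counter(torch.zeros(1, 1, 4))       # offset is now 8
        counter.set_streaming_state(saved)  # rewind to offset == 4
        assert counter.get_streaming_state()[""].offset == 4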
|
|
|
|
|
@dataclass
class _NullState:
    def reset(self) -> None:
        pass


class StreamingContainer(StreamingModule[_NullState]):
    """Streaming module with no state of its own: it only propagates streaming
    mode to its children."""

    def _init_streaming_state(self, batch_size: int) -> _NullState:
        return _NullState()
|
|
|
|
|
@dataclass
class _StreamingAddState:
    previous_x: torch.Tensor | None = None
    previous_y: torch.Tensor | None = None

    def reset(self):
        self.previous_x = None
        self.previous_y = None


class StreamingAdd(StreamingModule[_StreamingAddState]):
    """Pointwise addition of two streams that may arrive at different rates:
    whatever cannot be summed yet is kept for the next call."""

    def _init_streaming_state(self, batch_size: int) -> _StreamingAddState:
        return _StreamingAddState()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        if self._streaming_state is None:
            return x + y
        else:
            prev_x = self._streaming_state.previous_x
            prev_y = self._streaming_state.previous_y
            if prev_x is not None:
                x = torch.cat([prev_x, x], dim=-1)
            if prev_y is not None:
                y = torch.cat([prev_y, y], dim=-1)
            # Only the overlapping time steps can be summed now; the tail of
            # the longer stream is kept until its counterpart arrives.
            m_l = min(x.shape[-1], y.shape[-1])
            self._streaming_state.previous_x = x[..., m_l:]
            self._streaming_state.previous_y = y[..., m_l:]
            return x[..., :m_l] + y[..., :m_l]
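

# A quick illustrative check added for documentation (not in the original
# file): summing two streams chunk by chunk, with mismatched chunk sizes,
# gives the same result as the offline sum.
def _example_streaming_add():
    add = StreamingAdd()
    x = torch.randn(1, 2, 12)
    y = torch.randn(1, 2, 12)
    x_chunks = [x[..., 0:3], x[..., 3:6], x[..., 6:9], x[..., 9:12]]
    y_chunks = [y[..., 0:2], y[..., 2:6], y[..., 6:8], y[..., 8:12]]
    outs = []
    with add.streaming(batch_size=1):
        for cx, cy in zip(x_chunks, y_chunks):
            outs.append(add(cx, cy))
    assert torch.allclose(torch.cat(outs, dim=-1), x + y)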
|
|
|
|
|
@dataclass
class _StreamingConvState:
    previous: torch.Tensor | None = None

    def reset(self):
        self.previous = None


class RawStreamingConv1d(nn.Conv1d, StreamingModule[_StreamingConvState]):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.padding[0] == 0, "Padding should be handled outside."
        assert (
            self.stride[0] <= self.kernel_size[0]
        ), "stride must not exceed kernel_size."

    def _init_streaming_state(self, batch_size: int) -> _StreamingConvState:
        return _StreamingConvState()
|
|
|
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        stride = self.stride[0]
        # Effective kernel size, accounting for dilation.
        kernel = (self.kernel_size[0] - 1) * self.dilation[0] + 1
        if self._streaming_state is None:
            return super().forward(input)
        else:
            # Prepend whatever was left over from the previous call.
            previous = self._streaming_state.previous
            if previous is not None:
                input = torch.cat([previous, input], dim=-1)
            B, C, T = input.shape
            # Number of frames for which we already have the full receptive field.
            num_frames = max(0, int(math.floor((T - kernel) / stride) + 1))
            offset = num_frames * stride
            # Everything before `offset` will never be needed again; keep the
            # rest for the next call.
            self._streaming_state.previous = input[..., offset:]
            if num_frames > 0:
                input_length = (num_frames - 1) * stride + kernel
                out = super().forward(input[..., :input_length])
            else:
                # Not enough data yet to output a single frame.
                out = torch.empty(
                    B, self.out_channels, 0, device=input.device, dtype=input.dtype
                )
            return out
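

# Illustration added for documentation (not in the original file): streaming a
# convolution chunk by chunk matches the offline result; `test` below performs
# the exhaustive version of this check.
def _example_streaming_conv():
    conv = RawStreamingConv1d(2, 4, kernel_size=5, stride=2)
    x = torch.randn(1, 2, 20)
    outs = []
    with torch.no_grad(), conv.streaming(batch_size=1):
        for off in range(0, 20, 3):
            outs.append(conv(x[..., off : off + 3]))
    y_stream = torch.cat(outs, dim=-1)
    with torch.no_grad():
        y = conv(x)
    assert torch.allclose(y[..., : y_stream.shape[-1]], y_stream, atol=1e-6)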
|
|
|
|
|
@dataclass
class _StreamingConvTrState:
    partial: torch.Tensor | None = None

    def reset(self):
        self.partial = None


class RawStreamingConvTranspose1d(
    nn.ConvTranspose1d, StreamingModule[_StreamingConvTrState]
):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.padding[0] == 0, "Padding should be handled outside."
        assert self.dilation[0] == 1, "No dilation for now."
        assert (
            self.stride[0] <= self.kernel_size[0]
        ), "stride must not exceed kernel_size."
        assert self.output_padding[0] == 0, "Output padding not supported."

    def _init_streaming_state(self, batch_size: int) -> _StreamingConvTrState:
        return _StreamingConvTrState()
|
|
|
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, T = x.shape
        stride = self.stride[0]
        kernel = self.kernel_size[0]
        if self._streaming_state is None:
            return super().forward(x)
        else:
            if T == 0:
                return torch.empty(
                    B, self.out_channels, 0, device=x.device, dtype=x.dtype
                )
            out = super().forward(x)
            OT = out.shape[-1]
            partial = self._streaming_state.partial
            if partial is not None:
                # The last steps of the previous output were incomplete: they
                # still receive contributions from the current input frames.
                # Add them back in, removing the bias once since it was already
                # added when `partial` was first computed.
                PT = partial.shape[-1]
                if self.bias is not None:
                    out[..., :PT] += partial - self.bias[:, None]
                else:
                    out[..., :PT] += partial
            # For T input steps, the output has S * (T - 1) + K steps. The next
            # input frame will start at offset S * T, so the last
            # S * (T - 1) + K - S * T = K - S steps are not final yet, and are
            # kept as `partial` for the next call.
            invalid_steps = kernel - stride
            partial = out[..., OT - invalid_steps :]
            out = out[..., : OT - invalid_steps]
            self._streaming_state.partial = partial
            return out
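

# Illustration added for documentation (not in the original file): a streaming
# transposed convolution withholds the last `kernel - stride` output steps of
# each call until the next input arrives, so the concatenated streaming output
# matches the offline one except for those trailing steps.
def _example_streaming_convtr():
    convtr = RawStreamingConvTranspose1d(4, 2, kernel_size=6, stride=3)
    x = torch.randn(1, 4, 9)
    outs = []
    with torch.no_grad(), convtr.streaming(batch_size=1):
        for off in range(0, 9, 2):
            outs.append(convtr(x[..., off : off + 2]))
    z_stream = torch.cat(outs, dim=-1)
    with torch.no_grad():
        z = convtr(x)
    assert torch.allclose(z[..., : z_stream.shape[-1]], z_stream, atol=1e-6)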
|
|
|
|
|
def test():
    torch.manual_seed(1234)
    device = "cpu"

    kernel_sizes = [1, 3, 4, 8, 15, 16]
    strides = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    chin = 6
    chout = 12

    for kernel, stride in itertools.product(kernel_sizes, strides):
        if stride > kernel:
            continue
        conv = RawStreamingConv1d(chin, chout, kernel, stride).to(device)
        convtr = RawStreamingConvTranspose1d(chout, chin, kernel, stride).to(device)

        for length in [4, 8, 32, 54, 65, 128, 1043]:
            print(f"ksize {kernel} strides {stride} len {length}")
            if length < kernel:
                continue
            batch_size = 3
            x = torch.randn(batch_size, chin, length).to(device)
            y = conv(x)
            z = convtr(y)
            for chunk_size in [1, 3, 5, 8]:
                ys = []
                zs = []
                with conv.streaming(batch_size), convtr.streaming(batch_size):
                    for offset in range(0, length, chunk_size):
                        chunk = x[..., offset : offset + chunk_size]
                        ys.append(conv(chunk))
                        zs.append(convtr(ys[-1]))
                y_stream = torch.cat(ys, dim=-1)
                z_stream = torch.cat(zs, dim=-1)
                # The streaming outputs can be shorter: at the end of the
                # stream, the modules still hold data that was never flushed.
                y = y[..., : y_stream.shape[-1]]
                z = z[..., : z_stream.shape[-1]]
                assert y.shape == y_stream.shape, (y.shape, y_stream.shape)
                delta = (y_stream - y).norm() / y.norm()
                assert delta <= 1e-6, delta
                num_frames = int((length - kernel) / stride) + 1
                assert num_frames == y_stream.shape[-1]

                assert z.shape == z_stream.shape, (z.shape, z_stream.shape)
                delta = (z_stream - z).norm() / z.norm()
                assert delta <= 1e-6, (delta, (z_stream - z).abs().mean(dim=(0, 1)))


if __name__ == "__main__":
    with torch.no_grad():
        test()
|
|