import torch
import torch.nn as nn
from typing import Optional

from prelude_Block import PreludeBlock
from recurrent_Block import RecurrentBlock
from codaBlock import CodaBlock


class LatentRecurrentDepthLM(nn.Module):
    """Language model that scales effective depth at inference time by
    iterating a shared recurrent block between a fixed prelude (embedding)
    stage and a coda (unembedding) stage."""

    def __init__(self, vocab_size: int, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.prelude = PreludeBlock(vocab_size, d_model, num_heads, dropout)
        self.recurrent = RecurrentBlock(d_model, num_heads, dropout)
        self.coda = CodaBlock(d_model, vocab_size)

    def forward(self, x: torch.Tensor, num_iterations: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Embed the input tokens and run the prelude layers once.
        hidden = self.prelude(x, mask)
        # The latent recurrent state starts at zero and is refined on each pass.
        recurrent_state = torch.zeros_like(hidden)
        # Apply the same recurrent block num_iterations times; extra iterations
        # trade compute for depth without adding parameters.
        for _ in range(num_iterations):
            hidden, recurrent_state = self.recurrent(hidden, recurrent_state, mask)
        # Project the final hidden states to vocabulary logits.
        return self.coda(hidden)
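
# Minimal usage sketch. The hyperparameters below are illustrative
# assumptions, not values from the original code, and the shapes assume
# CodaBlock produces one logit vector per token position.
if __name__ == "__main__":
    model = LatentRecurrentDepthLM(vocab_size=32000, d_model=512, num_heads=8)
    tokens = torch.randint(0, 32000, (2, 16))  # (batch, seq_len) token ids
    logits = model(tokens, num_iterations=4)   # 4 passes through the recurrent block
    print(logits.shape)  # expected: torch.Size([2, 16, 32000])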