import torch
import torch.nn as nn
from typing import Callable, Optional

class FlattenTime(nn.Flatten):
    """
    Utility layer which always flattens the first two dimensions.

    Meant to convert a tensor of dimensions (Batch, Time, Channels, Height, Width)
    into a tensor of (Batch*Time, Channels, Height, Width).
    """

    def __init__(self) -> None:
        # The range is fixed to dims 0..1 so that only the two leading
        # dimensions are merged; all trailing dimensions are left untouched.
        super().__init__(start_dim=0, end_dim=1)
class UnflattenTime(nn.Module):
    """
    Utility layer which always unflattens (expands) the first dimension into
    two separate ones.

    Meant to convert a tensor of dimensions (Batch*Time, Channels, Height, Width)
    into a tensor of (Batch, Time, Channels, Height, Width).

    Parameters:
        batch_size: Size of the leading (batch) dimension of the output.
    """

    def __init__(self, batch_size: int) -> None:
        super().__init__()
        self.batch_size = batch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Split the first dimension of ``x`` into ``(batch_size, num_time_steps)``.

        Parameters:
            x: Tensor whose first dimension is expected to hold
                ``batch_size * num_time_steps`` entries.

        Returns:
            ``x`` with one additional dimension: the first dimension is
            replaced by ``(batch_size, num_time_steps)``.
        """
        # The number of time steps is derived from the data. Floor division
        # means a first dimension that is not a multiple of batch_size yields
        # sizes that do not multiply up to it; `unflatten` is then expected
        # to reject the input rather than silently drop samples.
        num_time_steps = x.shape[0] // self.batch_size
        return x.unflatten(0, (self.batch_size, num_time_steps))
class SqueezeMixin:
    """
    Utility mixin class providing helpers to wrap the ``__init__`` and
    ``forward`` call of other classes.

    The wrapped ``__init__`` provides two additional parameters,
    ``batch_size`` and ``num_timesteps``, and the wrapped ``forward`` unpacks
    the first dimension into batch and time before the call and repacks it
    afterwards.
    """

    def squeeze_init(
        self, batch_size: Optional[int], num_timesteps: Optional[int]
    ) -> None:
        """
        Store ``batch_size`` and ``num_timesteps`` on the instance.

        Parameters:
            batch_size: Size of the batch dimension, or a falsy value
                (``None`` / ``0``) to have it inferred.
            num_timesteps: Size of the time dimension, or a falsy value
                (``None`` / ``0``) to have it inferred.

        Raises:
            TypeError: If neither ``batch_size`` nor ``num_timesteps`` is given.
        """
        if not batch_size and not num_timesteps:
            raise TypeError("You need to specify either batch_size or num_timesteps.")

        # A missing value is stored as -1, which `reshape` in
        # `squeeze_forward` treats as "infer this dimension".
        self.batch_size = int(batch_size) if batch_size else -1
        self.num_timesteps = int(num_timesteps) if num_timesteps else -1

    def squeeze_forward(
        self, input_data: torch.Tensor, forward_method: Callable
    ) -> torch.Tensor:
        """
        Run ``forward_method`` on ``input_data`` with its first dimension
        expanded into ``(batch_size, num_timesteps)``.

        Parameters:
            input_data: Tensor whose first dimension combines batch and time.
            forward_method: Callable invoked with the expanded tensor; its
                result must support ``flatten(start_dim=0, end_dim=1)``.

        Returns:
            The result of ``forward_method`` with its first two dimensions
            merged back into one.
        """
        trailing_dims = input_data.shape[1:]
        inflated_input = input_data.reshape(
            self.batch_size, self.num_timesteps, *trailing_dims
        )
        inflated_output = forward_method(inflated_input)
        return inflated_output.flatten(start_dim=0, end_dim=1)

    def squeeze_param_dict(self, param_dict: dict) -> dict:
        """
        Add ``batch_size`` and ``num_timesteps`` to ``param_dict``.

        Parameters:
            param_dict: Dictionary to extend. It is updated in place.

        Returns:
            The same ``param_dict`` object, now containing the two extra keys.
        """
        param_dict.update(
            batch_size=self.batch_size,
            num_timesteps=self.num_timesteps,
        )
        return param_dict