credit.preblock.base

Classes

BasePreblock

Base class for all preblocks. Enforces the forward signature and provides the from_config classmethod used by the registry.

Module Contents

class credit.preblock.base.BasePreblock(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Base class for all preblocks. Enforces the forward signature and provides the from_config classmethod used by the registry.
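The following minimal sketch shows the kind of subclass the forward contract implies: a dict goes in, a new dict comes out, and the caller's batch is left alone. The sketch assumes the batch is nested as batch[data_type][source], which is how the _copy_batch description reads. So that it runs without PyTorch, it uses a hypothetical `StandInPreblock` in place of `BasePreblock`. The names `ScaleInputs`, `factor`, and the source key `"src"` are all illustrative and are not part of credit.

```python
from typing import Any, Dict

Batch = Dict[str, Dict[str, Any]]  # assumed layout: batch[data_type][source] -> value


class StandInPreblock:
    """Hypothetical stand-in for BasePreblock (the real class subclasses torch.nn.Module)."""

    def __call__(self, batch: Batch) -> Batch:
        # torch.nn.Module.__call__ similarly dispatches to forward()
        return self.forward(batch)

    def forward(self, batch: Batch) -> Batch:
        raise NotImplementedError


class ScaleInputs(StandInPreblock):
    """Hypothetical preblock: multiply every 'input' value by a constant factor."""

    def __init__(self, factor: float = 1.0) -> None:
        self.factor = factor

    def forward(self, batch: Batch) -> Batch:
        # Fresh dicts at the data_type and source levels; the caller's batch is untouched
        out = {data_type: dict(sources) for data_type, sources in batch.items()}
        for name, value in out.get("input", {}).items():
            # Out-of-place arithmetic produces a new value instead of editing the old one
            out["input"][name] = value * self.factor
        return out


batch = {"input": {"src": 3.0}, "target": {"src": 4.0}}
result = ScaleInputs(factor=2.0)(batch)
print(result)  # {'input': {'src': 6.0}, 'target': {'src': 4.0}}
print(batch)   # {'input': {'src': 3.0}, 'target': {'src': 4.0}}
```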

VALID_DATA_TYPES = ('input', 'target')
_copy_batch(batch: dict) → dict

Shallow-copy the batch so forward() never mutates the caller’s dict.

Creates new dict objects at the data_type and source levels. Tensor values are shared rather than copied, because preblocks must not modify tensors in place anyway (in-place updates can break autograd).
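Below is a sketch of the shallow copy described above. It is not credit's actual implementation. The layout is an assumption: batch[data_type][source] holds either a tensor or a dict of tensors, and the sketch copies the dict in both cases. `copy_batch` is a hypothetical free-function version of `_copy_batch`. A plain list stands in for a tensor so that object identity is easy to check.

```python
from typing import Any, Dict


def copy_batch(batch: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Hypothetical sketch: new dicts down to the source level, leaf values shared."""
    return {
        data_type: {
            # Copy a per-source dict if present; otherwise share the value itself
            source: dict(value) if isinstance(value, dict) else value
            for source, value in sources.items()
        }
        for data_type, sources in batch.items()
    }


leaf = [1.0, 2.0]  # stand-in for a tensor
batch = {"input": {"src": {"t2m": leaf}}, "target": {"src": {"t2m": leaf}}}
copied = copy_batch(batch)

copied["input"]["src"]["new_var"] = [0.0]     # adding a key only touches the copy
print("new_var" in batch["input"]["src"])     # False
print(copied["input"]["src"]["t2m"] is leaf)  # True: the leaf object is shared
```

Because the leaves are shared, an in-place edit such as `leaf.append(3.0)` would still be visible through both dicts. That is why the copy protects the dict structure but not the tensors themselves.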

forward(batch: dict) → dict
classmethod from_config(**kwargs)
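The class docstring says from_config is used by the registry, but the registry itself is not documented on this page. The sketch below shows one common pattern under stated assumptions. In it, a name-to-class dict is filled by a decorator, and a builder pops a `"type"` key from the config and passes the rest to from_config. All of `Preblock`, `PREBLOCK_REGISTRY`, `register`, `Identity`, and `build_preblock` are hypothetical names, not credit's API.

```python
from typing import Any, Callable, Dict, Type


class Preblock:
    """Hypothetical stand-in for BasePreblock."""

    @classmethod
    def from_config(cls, **kwargs: Any) -> "Preblock":
        # Assumption: config entries map directly onto constructor arguments
        return cls(**kwargs)

    def forward(self, batch: dict) -> dict:
        raise NotImplementedError


PREBLOCK_REGISTRY: Dict[str, Type[Preblock]] = {}


def register(name: str) -> Callable[[Type[Preblock]], Type[Preblock]]:
    """Decorator that records a preblock class under a config name."""
    def decorator(cls: Type[Preblock]) -> Type[Preblock]:
        PREBLOCK_REGISTRY[name] = cls
        return cls
    return decorator


@register("identity")
class Identity(Preblock):
    def __init__(self, verbose: bool = False) -> None:
        self.verbose = verbose

    def forward(self, batch: dict) -> dict:
        return {data_type: dict(sources) for data_type, sources in batch.items()}


def build_preblock(config: Dict[str, Any]) -> Preblock:
    # Assumption: config looks like {"type": "<registered name>", **constructor kwargs}
    kwargs = dict(config)  # copy so the caller's config is not modified
    cls = PREBLOCK_REGISTRY[kwargs.pop("type")]
    return cls.from_config(**kwargs)


block = build_preblock({"type": "identity", "verbose": True})
print(type(block).__name__, block.verbose)  # Identity True
```

Routing construction through a classmethod means that a subclass needing non-trivial config parsing can override from_config, and the registry code does not have to change.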