credit.preblock.concat

Classes

ConcatToTensor

End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).

Module Contents

class credit.preblock.concat.ConcatToTensor(*args: Any, **kwargs: Any)

Bases: credit.preblock.base.BasePreblock

End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).

Expects a batch dict of the form:

batch[source][data_type][var_name] -> torch.Tensor

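where each tensor has shape (batch, channel, time, lon, lat) and concatenation is performed along dim=1 (the channel dimension). Traversal follows key insertion order: for each source, the var_names under each data_type are concatenated in turn, then traversal moves on to the next source, and so on. The channel order of the output therefore depends on the order in which keys were inserted into the batch dict.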
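Because channel order follows insertion order, the channel range occupied by each variable can be recovered by walking the dict the same way. The sketch below is illustrative only and is not part of `credit`: the helper `channel_slices` and the keys `"era5"`, `"static"`, `"input"`, `"t"`, `"u"`, and `"z"` are hypothetical, and it assumes a batch with no "target" data_type and no metadata keys.

```python
import torch


def channel_slices(batch):
    """Map (source, data_type, var_name) to its channel slice.

    Hypothetical helper: walks the nested dict in insertion order and
    accumulates the size of dim=1 (channel) for each tensor.
    """
    slices, offset = {}, 0
    for source, data_types in batch.items():
        for data_type, variables in data_types.items():
            for var_name, tensor in variables.items():
                n_channels = tensor.shape[1]
                slices[(source, data_type, var_name)] = slice(offset, offset + n_channels)
                offset += n_channels
    return slices


# Hypothetical batch: shapes are (batch, channel, time, lon, lat).
batch = {
    "era5": {"input": {"t": torch.zeros(2, 3, 1, 8, 4), "u": torch.zeros(2, 1, 1, 8, 4)}},
    "static": {"input": {"z": torch.zeros(2, 2, 1, 8, 4)}},
}
slices = channel_slices(batch)
print(slices[("static", "input", "z")])  # slice(4, 6, None)
```

Reordering the keys in `batch` would shift these slices, so any code that indexes channels by position should build the dict in a fixed order.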

Metadata keys are passed through unchanged and are not concatenated.

Returns either:

(input_tensor, metadata)                    # if no "target" data_type present
(input_tensor, target_tensor, metadata)     # if "target" is present
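The behaviour described above can be approximated with a few lines of PyTorch. This is a minimal sketch, not the library's implementation: the function `concat_batch` is a hypothetical stand-in for `forward`, and it assumes that metadata lives under a top-level `"metadata"` key and that every data_type other than `"target"` contributes to the input tensor. The source, data_type, and variable names in the example batch are also made up.

```python
import torch


def concat_batch(batch):
    """Hypothetical stand-in for ConcatToTensor.forward (not the library's code)."""
    inputs, targets = [], []
    metadata = {}
    for source, data_types in batch.items():
        if source == "metadata":  # assumption: metadata sits at the top level
            metadata = data_types
            continue
        for data_type, variables in data_types.items():
            # Tensors under "target" go to the target tensor, the rest to the input.
            bucket = targets if data_type == "target" else inputs
            bucket.extend(variables.values())

    input_tensor = torch.cat(inputs, dim=1)  # concatenate along the channel dim
    if targets:
        return input_tensor, torch.cat(targets, dim=1), metadata
    return input_tensor, metadata


# Hypothetical batch: shapes are (batch, channel, time, lon, lat).
batch = {
    "era5": {
        "input": {"t": torch.zeros(2, 3, 1, 8, 4), "u": torch.ones(2, 3, 1, 8, 4)},
        "target": {"t": torch.zeros(2, 3, 1, 8, 4)},
    },
    "metadata": {"datetime": [0, 1]},
}
x, y, meta = concat_batch(batch)
print(x.shape)  # torch.Size([2, 6, 1, 8, 4])
print(y.shape)  # torch.Size([2, 3, 1, 8, 4])
```

Note that the return value is a 2-tuple or a 3-tuple depending on the batch contents, so callers need to handle both cases.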

Example config:

type: "concatenate_to_tensor"
args: {}
forward(batch)