credit.preblock.concat
Classes
ConcatToTensor: End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).
Module Contents
- class credit.preblock.concat.ConcatToTensor(*args: Any, **kwargs: Any)
Bases: credit.preblock.base.BasePreblock

End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and, optionally, a single target tensor).
Expects a batch dict of the form:
batch[source][data_type][var_name] -> torch.Tensor
where every tensor has shape (batch, channel, time, lon, lat) and concatenation is performed along dim=1 (channel), so all other dimensions must match. Traversal follows dict insertion order: for each source, the tensors of all var_names under a data_type are concatenated in order, then the next source is processed, and so on.
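The traversal and concatenation rule can be illustrated with a minimal sketch. It assumes only `torch`; `concat_in_order` is a hypothetical helper, not part of `credit`, and the source names `"era5"` and `"static"` are made up for illustration. The sketch gathers a single data_type at a time; how the class itself merges several non-target data_types into the input tensor is not specified here.

```python
import torch


def concat_in_order(batch, data_type):
    """Hypothetical helper (not the credit implementation): concatenate every
    tensor stored under `data_type` along dim=1, visiting sources and
    var_names in dict insertion order."""
    tensors = []
    for by_type in batch.values():                          # sources, in insertion order
        for tensor in by_type.get(data_type, {}).values():  # var_names, in insertion order
            tensors.append(tensor)                          # (batch, channel, time, lon, lat)
    return torch.cat(tensors, dim=1)                        # only the channel dim grows


B, T, LON, LAT = 2, 1, 8, 4
batch = {
    "era5": {                                               # hypothetical source name
        "input": {
            "t": torch.zeros(B, 3, T, LON, LAT),            # 3 channels filled with 0
            "u": torch.ones(B, 3, T, LON, LAT),             # 3 channels filled with 1
        },
    },
    "static": {                                             # hypothetical source name
        "input": {"z": torch.full((B, 1, T, LON, LAT), 2.0)},  # 1 channel filled with 2
    },
}

x = concat_in_order(batch, "input")
print(x.shape)  # torch.Size([2, 7, 1, 8, 4])
```

Channels 0-2 of the result come from `t`, 3-5 from `u`, and 6 from `z`, mirroring insertion order; the channel count is simply the sum of the per-variable channel counts.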
metadata keys are passed through as-is (they are not concatenated).

Returns either:
(input_tensor, metadata)                 # if no "target" data_type is present
(input_tensor, target_tensor, metadata)  # if "target" is present
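The two return shapes can be sketched as follows. This is a hypothetical re-implementation of the documented contract, not the `credit` source, and it rests on two assumptions that the text above does not confirm: metadata sits under a top-level `"metadata"` key of the batch, and every data_type other than `"target"` is concatenated into the input tensor. The name `concat_batch` and the `"init_time"` metadata entry are illustrative only.

```python
import torch


def concat_batch(batch):
    """Hypothetical sketch of the documented return contract.

    ASSUMPTIONS (not confirmed by the docs): metadata lives under a
    top-level "metadata" key, and every data_type except "target" is
    concatenated into the input tensor.
    """
    inputs, targets = [], []
    for source, by_type in batch.items():
        if source == "metadata":
            continue                                # passed through, never concatenated
        for data_type, variables in by_type.items():
            bucket = targets if data_type == "target" else inputs
            bucket.extend(variables.values())       # insertion order preserved
    metadata = batch.get("metadata", {})
    input_tensor = torch.cat(inputs, dim=1)         # concatenate along channel dim
    if not targets:
        return input_tensor, metadata               # no "target" data_type present
    return input_tensor, torch.cat(targets, dim=1), metadata


B, T, LON, LAT = 2, 1, 8, 4
batch = {
    "era5": {
        "input": {"t": torch.zeros(B, 3, T, LON, LAT)},
        "target": {"t": torch.ones(B, 3, T, LON, LAT)},
    },
    "metadata": {"init_time": "2020-01-01T00"},     # hypothetical metadata entry
}

x, y, meta = concat_batch(batch)                    # "target" present, so a 3-tuple
print(x.shape, y.shape)  # torch.Size([2, 3, 1, 8, 4]) torch.Size([2, 3, 1, 8, 4])
```

Because the arity of the result depends on the batch contents, a caller that must handle both cases can branch on the length of the returned tuple.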
Example config:
type: "concatenate_to_tensor"
args: {}
- forward(batch)