credit.datasets.om4_multistep_batcher
=====================================

.. py:module:: credit.datasets.om4_multistep_batcher


Attributes
----------

.. autoapisummary::

   credit.datasets.om4_multistep_batcher.conf


Classes
-------

.. autoapisummary::

   credit.datasets.om4_multistep_batcher.StandardScaler
   credit.datasets.om4_multistep_batcher.Ocean_MultiStep_Batcher
   credit.datasets.om4_multistep_batcher.Ocean_Tensor_Batcher
   credit.datasets.om4_multistep_batcher.Predict_Ocean_Batcher


Functions
---------

.. autoapisummary::

   credit.datasets.om4_multistep_batcher.load_transform


Module Contents
---------------

.. py:function:: load_transform(conf)

   Load the data needed for normalization, as described by ``conf``, and
   return the resulting ``StandardScaler`` object. Only the essentials are kept.


.. py:class:: StandardScaler(data_mean: xarray.Dataset, data_std: xarray.Dataset, prognostic_vars: str, boundary_vars: str, wet_mask: torch.Tensor)

   .. py:attribute:: prognostic_mean


   .. py:attribute:: prognostic_std


   .. py:attribute:: boundary_mean


   .. py:attribute:: boundary_std


   .. py:attribute:: wet_mask


   .. py:attribute:: _prognostic_mean_np


   .. py:attribute:: _prognostic_std_np


   .. py:attribute:: _wet_mask_np


   .. py:method:: _to_tensor(array: numpy.ndarray, device: torch.device) -> torch.Tensor

      Convert a NumPy array to a tensor on the specified device.



   .. py:method:: normalize_prognostics(data: xarray.Dataset, fill_nan=True, fill_value=0.0) -> xarray.Dataset

      Normalize the prognostic variables of an input ``xarray.Dataset``,
      optionally replacing NaN values with ``fill_value``.



   .. py:method:: normalize_boundary(data: xarray.Dataset, fill_nan=True, fill_value=0.0) -> xarray.Dataset

      Normalize the boundary-condition variables of an ``xarray.Dataset``,
      optionally replacing NaN values with ``fill_value``.



   .. py:method:: unnormalize_prognostics(data: xarray.Dataset) -> xarray.Dataset

      Unnormalize the prognostic variables of an output ``xarray.Dataset``.
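
      The round trip with ``normalize_prognostics`` can be sketched as follows.
      This is only an illustration: it assumes standard scaling
      ``(x - mean) / std`` applied per variable, and the variable name
      ``thetao`` and all values are hypothetical.

      .. code-block:: python

         import numpy as np
         import xarray as xr

         # Hypothetical dataset with one prognostic variable and one NaN (e.g. a land cell).
         data = xr.Dataset(
             {"thetao": (("lat", "lon"), np.array([[12.0, np.nan], [8.0, 10.0]]))}
         )

         # Hypothetical per-variable statistics stored as 0-d variables.
         data_mean = xr.Dataset({"thetao": ((), 10.0)})
         data_std = xr.Dataset({"thetao": ((), 2.0)})

         # Standard scaling, then NaN replacement (assumed fill_value=0.0).
         normalized = ((data - data_mean) / data_std).fillna(0.0)

         # Inverse operation: scale back and add the mean.
         restored = normalized * data_std + data_mean

      Under these assumptions a filled NaN is restored as the variable mean,
      not as NaN, so the original missing-value pattern is not recovered.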



   .. py:method:: normalize_tensor_prognostics(data: torch.Tensor, fill_nan=True, fill_value=0.0) -> torch.Tensor

      Normalize the prognostic variables of a ``torch.Tensor``,
      optionally replacing NaN values with ``fill_value``.



   .. py:method:: unnormalize_tensor_prognostics(data: torch.Tensor) -> torch.Tensor

      Unnormalize the prognostic variables of a ``torch.Tensor``.



   .. py:method:: normalize_numpy_prognostics(data: numpy.ndarray, fill_nan=True, fill_value=0.0) -> numpy.ndarray

      Normalize the prognostic variables of a NumPy array,
      optionally replacing NaN values with ``fill_value``.



   .. py:method:: unnormalize_numpy_prognostics(data: numpy.ndarray) -> numpy.ndarray

      Unnormalize the prognostic variables of a NumPy array.



   .. py:method:: transform_array(data: torch.Tensor, fill_nan=True, fill_value=0.0) -> torch.Tensor

      Normalize prognostic variables of a torch tensor.
      Expects shape (B, C, H, W) or (B, C, T, H, W).



   .. py:method:: inverse_transform(data: torch.Tensor) -> torch.Tensor

      Unnormalize prognostic variables of a torch tensor.
      Expects shape (B, C, H, W) or (B, C, T, H, W).
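
      A minimal sketch of the round trip with ``transform_array`` for the
      ``(B, C, H, W)`` case. It assumes per-channel standard scaling and NaN
      replacement after scaling, neither of which is confirmed by this page.
      The helpers ``normalize`` and ``unnormalize`` and the statistics are
      hypothetical stand-ins, not the class's own code.

      .. code-block:: python

         import torch

         # Hypothetical statistics for C=2 channels, shaped to broadcast over (B, C, H, W).
         mean = torch.tensor([10.0, -2.0]).view(1, 2, 1, 1)
         std = torch.tensor([2.0, 0.5]).view(1, 2, 1, 1)


         def normalize(x, fill_nan=True, fill_value=0.0):
             """Standard-scale each channel, then optionally replace NaNs."""
             z = (x - mean) / std
             if fill_nan:
                 z = torch.where(torch.isnan(z), torch.full_like(z, fill_value), z)
             return z


         def unnormalize(z):
             """Invert the scaling: multiply by std and add the mean."""
             return z * std + mean


         # One batch element, two channels, a 1x2 grid; the last value is NaN.
         x = torch.tensor([12.0, 8.0, -1.5, float("nan")]).view(1, 2, 1, 2)
         z = normalize(x)
         x_back = unnormalize(z)

      In this sketch the NaN cell is normalized to ``0.0`` and therefore maps
      back to the channel mean (``-2.0``) rather than to NaN.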



.. py:class:: Ocean_MultiStep_Batcher(conf, seed=42, rank=0, world_size=1, batch_size=1, shuffle=True)

   Bases: :py:obj:`torch.utils.data.Dataset`


   Ocean dataset that handles both single-step and multi-step autoregressive training.
   Returns tensors with shape (batch, channels, time, lat, lon).


   .. py:attribute:: input_length


   .. py:attribute:: output_length


   .. py:attribute:: forecast_len


   .. py:attribute:: seed
      :value: 42



   .. py:attribute:: rank
      :value: 0



   .. py:attribute:: world_size
      :value: 1



   .. py:attribute:: shuffle
      :value: True



   .. py:attribute:: batch_size
      :value: 1



   .. py:attribute:: rng


   .. py:attribute:: _prognostic_data


   .. py:attribute:: _boundary_data


   .. py:attribute:: num_prognostic_vars


   .. py:attribute:: num_boundary_vars


   .. py:attribute:: wet


   .. py:attribute:: wet_surface


   .. py:attribute:: normalize


   .. py:attribute:: size


   .. py:attribute:: sampler


   .. py:attribute:: batch_indices
      :value: None



   .. py:attribute:: time_steps
      :value: None



   .. py:attribute:: forecast_step_counts
      :value: None



   .. py:attribute:: current_epoch
      :value: None



   .. py:method:: initialize_batch()

      Initialize the batch indices from the indices produced by the ``DistributedSampler``.



   .. py:method:: __len__()


   .. py:method:: set_epoch(epoch)

      Set epoch for distributed training.
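
      The effect of setting the epoch on a ``DistributedSampler`` can be
      illustrated with plain PyTorch. This is a sketch of the general
      mechanism, not of this class's internals. The toy dataset and the
      helper ``indices_for_epoch`` are hypothetical.

      .. code-block:: python

         import torch
         from torch.utils.data import DistributedSampler, TensorDataset

         # Toy dataset with 8 samples.
         dataset = TensorDataset(torch.arange(8))

         # Passing num_replicas and rank explicitly avoids needing an
         # initialized process group; here we emulate two ranks in one process.
         samplers = [
             DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=True, seed=42)
             for r in range(2)
         ]


         def indices_for_epoch(epoch):
             """Return the per-rank index lists for the given epoch."""
             shards = []
             for sampler in samplers:
                 sampler.set_epoch(epoch)  # the shuffle is seeded with seed + epoch
                 shards.append(list(sampler))
             return shards


         epoch0 = indices_for_epoch(0)
         epoch1 = indices_for_epoch(1)

      Each rank receives a disjoint half of the indices, and calling
      ``set_epoch`` with the same value reproduces the same shuffle.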



   .. py:method:: batches_per_epoch()


   .. py:method:: _get_batch_samples(sample_indices)

      Get multiple ocean samples efficiently by batching xarray operations.
      Returns a batch dictionary of stacked tensors.
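
      The stacking step can be sketched as below. The dictionary keys
      ``"input"`` and ``"target"``, the helper names, and the tensor sizes are
      hypothetical. The real method reads from xarray data, where this sketch
      uses constant tensors instead.

      .. code-block:: python

         import torch

         C, T, H, W = 3, 2, 4, 5  # channels, time, lat, lon (illustrative sizes)


         def make_sample(idx):
             """Stand-in for one sample with tensors shaped (C, time, H, W)."""
             return {
                 "input": torch.full((C, T, H, W), float(idx)),
                 "target": torch.full((C, 1, H, W), float(idx + 1)),
             }


         def stack_samples(sample_indices):
             """Stack per-sample tensors along a new leading batch dimension."""
             samples = [make_sample(i) for i in sample_indices]
             return {
                 key: torch.stack([s[key] for s in samples], dim=0)
                 for key in samples[0]
             }


         batch = stack_samples([0, 3, 7])
         # batch["input"] has shape (batch, channels, time, lat, lon)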



   .. py:method:: __getitem__(_)

      Returns batch with tensors shaped (batch, channels, time, lat, lon).
      Optimized for speed.



   .. py:method:: _get_ocean_sample(idx)

      Get a single ocean sample with the proper time dimensions.
      Returns tensors with shape (channels, time, lat, lon).



.. py:class:: Ocean_Tensor_Batcher(conf, seed=42, rank=0, world_size=1, batch_size=1, shuffle=True)

   Bases: :py:obj:`torch.utils.data.Dataset`


   Ocean dataset that handles both single-step and multi-step autoregressive training.
   Loads cached ``.pt`` samples while preserving the autoregressive logic.


   .. py:attribute:: wet


   .. py:attribute:: wet_surface


   .. py:attribute:: seed
      :value: 42



   .. py:attribute:: rank
      :value: 0



   .. py:attribute:: world_size
      :value: 1



   .. py:attribute:: shuffle
      :value: True



   .. py:attribute:: batch_size
      :value: 1



   .. py:attribute:: input_length


   .. py:attribute:: output_length


   .. py:attribute:: forecast_len


   .. py:attribute:: samples_dir


   .. py:attribute:: size


   .. py:attribute:: sampler


   .. py:attribute:: batch_indices
      :value: None



   .. py:attribute:: time_steps
      :value: None



   .. py:attribute:: forecast_step_counts
      :value: None



   .. py:attribute:: current_epoch
      :value: None



   .. py:attribute:: num_prognostic_vars


   .. py:method:: initialize_batch()


   .. py:method:: __len__()


   .. py:method:: set_epoch(epoch)


   .. py:method:: batches_per_epoch()


   .. py:method:: _get_batch_samples(sample_indices)

      Load multiple cached ``.pt`` samples and concatenate them along the time dimension.
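
      A minimal sketch of loading cached tensors and joining them along the
      time axis. The file naming scheme ``sample_<idx>.pt``, the sample
      layout ``(channels, time, lat, lon)``, and the sizes are assumptions
      made for illustration only.

      .. code-block:: python

         import tempfile
         from pathlib import Path

         import torch

         C, H, W = 2, 3, 4  # channels, lat, lon (illustrative sizes)

         with tempfile.TemporaryDirectory() as tmp:
             samples_dir = Path(tmp)

             # Write three single-time-step samples to disk as .pt files.
             for idx in range(3):
                 sample = torch.full((C, 1, H, W), float(idx))
                 torch.save(sample, samples_dir / f"sample_{idx}.pt")

             # Load them back in order.
             loaded = [torch.load(samples_dir / f"sample_{idx}.pt") for idx in range(3)]

         # Concatenate along the time dimension (dim=1 for a (C, T, H, W) layout).
         sequence = torch.cat(loaded, dim=1)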



   .. py:method:: __getitem__(_)

      Returns a batch with tensors shaped (batch, channels, time, lat, lon),
      assembled using the autoregressive logic.



   .. py:method:: _get_ocean_sample(idx)

      Load a cached sample from disk.



.. py:class:: Predict_Ocean_Batcher(conf, forecast_windows, seed=42, rank=0, world_size=1, batch_size=1, shuffle=True)

   Bases: :py:obj:`Ocean_MultiStep_Batcher`


   Ocean dataset that uses fixed forecast windows instead of continuous sampling.
   Accepts a list of [start, end] datetime string pairs for forecasting.


   .. py:attribute:: valid_indices
      :value: []



   .. py:attribute:: forecast_windows


   .. py:attribute:: size


   .. py:attribute:: forecast_len
      :value: 0



   .. py:attribute:: sampler


   .. py:attribute:: batch_indices


   .. py:attribute:: batch_size


   .. py:attribute:: current_epoch
      :value: 0



   .. py:attribute:: batch_call_count
      :value: 0



   .. py:method:: _convert_windows_to_indices()

      Convert the datetime-string windows to time indices.
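
      One way to perform such a conversion is a sorted search on the time
      axis. The sketch below assumes a sorted daily time axis and inclusive
      window ends. The helper ``window_to_indices`` and the dates are
      hypothetical and do not describe the actual implementation.

      .. code-block:: python

         import numpy as np

         # Hypothetical daily time axis: 2000-01-01 through 2000-01-10.
         times = np.arange("2000-01-01", "2000-01-11", dtype="datetime64[D]")

         forecast_windows = [
             ["2000-01-02", "2000-01-05"],
             ["2000-01-07", "2000-01-09"],
         ]


         def window_to_indices(window):
             """Map a [start, end] string pair to inclusive (first, last) indices."""
             start, end = (np.datetime64(s) for s in window)
             first = int(np.searchsorted(times, start, side="left"))
             last = int(np.searchsorted(times, end, side="right")) - 1
             return first, last


         index_windows = [window_to_indices(w) for w in forecast_windows]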



   .. py:method:: _build_valid_indices()

      Build the list of valid sample indices from the forecast windows.
      Each window contributes exactly one starting index: the first valid one.
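
      The "one index per window" rule can be sketched as follows. The
      validity criterion used here (the input steps starting at the index
      must fit inside the time axis) is an assumption, as are the helper
      ``first_valid_start`` and all numbers.

      .. code-block:: python

         def first_valid_start(window, n_times, input_length):
             """Return the first index in the window whose input steps fit, else None."""
             first, last = window
             for start in range(first, last + 1):
                 if start + input_length <= n_times:
                     return start
             return None


         # Hypothetical inclusive (first, last) index windows on a 10-step time axis.
         index_windows = [(1, 4), (6, 8), (9, 9)]

         valid_indices = []
         for window in index_windows:
             start = first_valid_start(window, n_times=10, input_length=2)
             if start is not None:
                 valid_indices.append(start)  # at most one entry per window

      In this sketch the last window yields no index because two input steps
      starting at index 9 would run past the end of the time axis.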



   .. py:method:: __len__()

      Return the number of valid samples across all windows.



   .. py:method:: initialize_batch()

      Initialize batch indices from valid indices.



.. py:data:: conf

