credit.trainers.preflight
=========================

.. py:module:: credit.trainers.preflight

.. autoapi-nested-parse::

   preflight.py
   --------------
   Pre-training checks that run before the epoch loop starts.

   The goal is to catch silent hangs and OOM conditions early and emit
   clear, actionable error messages rather than letting jobs hang on the
   cluster for hours.

   Public API
   ----------
   estimate_dataloader_memory_gib(conf) -> float
       Pure function. Computes the expected peak DataLoader CPU RAM footprint
       from trainer and data config. No IO, fully testable.

   check_dataloader_startup(conf, loader, rank, timeout_s) -> None
       Logs the estimated DataLoader memory (warning if it looks dangerous),
       then fetches one batch from *loader* with a timeout. Raises
       RuntimeError with a user-friendly message if the fetch hangs.

   check_model_gpu_memory(conf, model, optimizer, rank) -> None
       Runs a synthetic forward/backward/optimizer step and logs peak VRAM.



Attributes
----------

.. autoapisummary::

   credit.trainers.preflight.logger


Functions
---------

.. autoapisummary::

   credit.trainers.preflight.estimate_dataloader_memory_gib
   credit.trainers.preflight._available_ram_gib
   credit.trainers.preflight._fetch_one_batch
   credit.trainers.preflight.check_dataloader_startup
   credit.trainers.preflight.check_model_gpu_memory


Module Contents
---------------

.. py:data:: logger

.. py:function:: estimate_dataloader_memory_gib(conf: dict) -> float

   Estimate peak CPU RAM used by the DataLoader (GiB).

   Formula::

       workers × prefetch_factor × batch_size × sample_bytes × 2

   where sample_bytes = H × W × total_channels × 4 (bytes per float32 value),
   and the factor of 2 counts the input and target tensors separately. The
   byte total is divided by 2**30 to give GiB.

   :param conf: Full training config dict.

   :returns: Estimated peak DataLoader RAM in GiB. Returns 0.0 if the config
             is missing required keys (non-fatal; the estimate is best-effort).
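
   A minimal sketch of the formula above, written against explicit numbers
   because the real config layout is not documented on this page. The names
   ``dataloader_memory_gib`` and ``estimate_from_flat_conf`` and the flat keys
   ``workers``, ``prefetch_factor``, ``batch_size``, ``height``, ``width`` and
   ``total_channels`` are hypothetical; only the arithmetic and the
   0.0-on-missing-keys fallback follow the docstring.

   .. code-block:: python

      def dataloader_memory_gib(
          workers: int,
          prefetch_factor: int,
          batch_size: int,
          height: int,
          width: int,
          total_channels: int,
          bytes_per_value: int = 4,  # float32
      ) -> float:
          """Peak DataLoader RAM in GiB, following the documented formula."""
          sample_bytes = height * width * total_channels * bytes_per_value
          # Factor of 2: input and target tensors are counted separately.
          total_bytes = workers * prefetch_factor * batch_size * sample_bytes * 2
          return total_bytes / 2**30


      def estimate_from_flat_conf(conf: dict) -> float:
          """Hypothetical wrapper: best-effort estimate, 0.0 if keys are missing."""
          try:
              return dataloader_memory_gib(
                  workers=conf["workers"],
                  prefetch_factor=conf["prefetch_factor"],
                  batch_size=conf["batch_size"],
                  height=conf["height"],
                  width=conf["width"],
                  total_channels=conf["total_channels"],
              )
          except (KeyError, TypeError):
              # Non-fatal: estimation is best-effort.
              return 0.0

   For example, 8 workers with ``prefetch_factor=4`` and a batch size of 2 on
   a 640 × 1280 grid with 100 channels gives
   ``dataloader_memory_gib(8, 4, 2, 640, 1280, 100) == 39.0625`` GiB.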


.. py:function:: _available_ram_gib() -> float

   Return available system RAM in GiB, or 0 if psutil is not installed.
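
   A minimal sketch of this helper using psutil's
   ``virtual_memory().available`` (reported in bytes), paired with a
   threshold check of the kind ``check_dataloader_startup`` could use to
   decide when memory "looks dangerous". The ``memory_looks_dangerous``
   helper and its 0.8 fraction are hypothetical, not the module's actual
   rule.

   .. code-block:: python

      def available_ram_gib() -> float:
          """Available system RAM in GiB, or 0.0 when psutil is not installed."""
          try:
              import psutil
          except ImportError:
              return 0.0
          # virtual_memory().available is reported in bytes.
          return psutil.virtual_memory().available / 2**30


      def memory_looks_dangerous(
          estimate_gib: float, available_gib: float, fraction: float = 0.8
      ) -> bool:
          """Hypothetical check: True when the estimate exceeds a fraction of free RAM.

          An available value of 0.0 means "unknown" (psutil missing), so no warning.
          """
          if available_gib <= 0.0:
              return False
          return estimate_gib > fraction * available_gib

   Treating 0.0 as "unknown" keeps a missing optional dependency from
   producing spurious warnings.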


.. py:function:: _fetch_one_batch(loader)

   Return the first batch from *loader*, or raise on error.


.. py:function:: check_dataloader_startup(conf: dict, loader, rank: int = 0, timeout_s: float = 300.0) -> None

   Run pre-training data loading checks (rank-0 only).

   1. Logs estimated DataLoader memory and warns if it looks dangerous.
   2. Attempts to fetch the first batch within *timeout_s* seconds.
      Raises RuntimeError with a clear, actionable message if it hangs.

   :param conf: Full training config dict.
   :param loader: Training DataLoader.
   :param rank: Global rank. Checks only run on rank 0.
   :param timeout_s: Seconds to wait for the first batch before failing.
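
   A minimal sketch of step 2, which assumes nothing about how this module
   actually implements its timeout. It runs ``next(iter(loader))`` in a
   daemon thread and waits on a ``queue.Queue`` with a timeout;
   ``fetch_first_batch_with_timeout`` and its error text are hypothetical.
   Python threads cannot be cancelled, so on timeout the hung fetch is
   abandoned rather than stopped; the daemon flag keeps it from blocking
   interpreter exit.

   .. code-block:: python

      import queue
      import threading


      def fetch_first_batch_with_timeout(loader, timeout_s: float = 300.0):
          """Return the first batch from ``loader`` or raise RuntimeError."""
          results: queue.Queue = queue.Queue(maxsize=1)

          def _worker():
              try:
                  results.put(("ok", next(iter(loader))))
              except StopIteration:
                  results.put(("err", RuntimeError("DataLoader yielded no batches.")))
              except Exception as exc:  # surface dataset/worker errors to the caller
                  results.put(("err", exc))

          threading.Thread(target=_worker, daemon=True).start()
          try:
              status, payload = results.get(timeout=timeout_s)
          except queue.Empty:
              raise RuntimeError(
                  f"No batch arrived within {timeout_s:g}s. The DataLoader may be "
                  "hung; check worker count, prefetch settings and available RAM."
              ) from None
          if status == "err":
              raise payload
          return payload

   Errors raised inside the worker thread are passed back through the queue
   and re-raised in the caller, so a crashing dataset fails fast instead of
   being reported as a timeout.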


.. py:function:: check_model_gpu_memory(conf: dict, model, optimizer, rank: int = 0) -> None

   Run a synthetic forward/backward/optimizer step and log peak VRAM.

   Creates a zero-filled batch of the expected input shape, runs it through
   the model, backprops, and steps the optimizer. Logs peak VRAM so users
   can verify their model fits on the target GPU before a long training run.

   Input channel count is inferred from the model config::

       frames × (channels × levels + surface_channels + input_only_channels)

   Skips silently if:

   - rank != 0 (only rank 0 reports)
   - CUDA is not available
   - the input channel count cannot be inferred (the inference yields 0)
   - any exception occurs during the synthetic pass
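
   The channel formula as a minimal sketch, assuming a flat dict whose keys
   reuse the formula's names; the real config nesting is not documented
   here, and ``infer_input_channels`` is a hypothetical name. Missing or
   non-numeric values yield 0, matching the skip condition above.

   .. code-block:: python

      def infer_input_channels(model_conf: dict) -> int:
          """Hypothetical: frames × (channels × levels + surface + input-only)."""
          try:
              per_frame = (
                  model_conf["channels"] * model_conf["levels"]
                  + model_conf["surface_channels"]
                  + model_conf["input_only_channels"]
              )
              return model_conf["frames"] * per_frame
          except (KeyError, TypeError):
              # 0 signals "cannot infer", which makes the check skip.
              return 0

   Worked example: 1 frame, 4 channels on 16 levels, 7 surface channels and
   3 input-only channels give 1 × (4 × 16 + 7 + 3) = 74 input channels; with
   2 frames the count doubles to 148.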

   :param conf: Full training config dict.
   :param model: The model (possibly DDP/FSDP wrapped).
   :param optimizer: The optimizer (used to test a full optimizer step).
   :param rank: Global rank. Check only runs on rank 0.
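
   A minimal sketch of the synthetic pass, assuming a model that maps one
   zero-filled tensor to one tensor output and taking the input shape as an
   argument; ``synthetic_step_peak_vram_gib`` and ``input_shape`` are
   hypothetical. It uses PyTorch's ``torch.cuda.reset_peak_memory_stats`` and
   ``torch.cuda.max_memory_allocated``, and returns ``None`` wherever the
   documented function would skip.

   .. code-block:: python

      try:
          import torch
      except ImportError:  # torch is optional for this sketch
          torch = None


      def synthetic_step_peak_vram_gib(model, optimizer, input_shape):
          """Hypothetical: one zero-input forward/backward/step; peak VRAM in GiB or None."""
          if torch is None or not torch.cuda.is_available():
              return None
          try:
              device = torch.device("cuda", torch.cuda.current_device())
              torch.cuda.reset_peak_memory_stats(device)
              x = torch.zeros(input_shape, device=device)
              out = model(x)
              # Scalar surrogate loss so backward() has something to differentiate.
              loss = out.float().mean()
              loss.backward()
              optimizer.step()
              optimizer.zero_grad(set_to_none=True)
              torch.cuda.synchronize(device)
              return torch.cuda.max_memory_allocated(device) / 2**30
          except Exception:
              # Mirrors the documented behaviour of skipping on any error.
              return None

   Two caveats apply to this sketch: ``max_memory_allocated`` counts memory
   held by tensors, not the CUDA context or allocator cache, so
   ``nvidia-smi`` will usually report more; and ``optimizer.step()`` really
   updates the parameters and optimizer state, even with zero inputs.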


