Time and Space marching schedule

I would like to implement both time and space marching schedules (similar to what is done in the Wave1D example; unfortunately, no code was given there: 1D Wave Equation - NVIDIA Docs).

In a post by user @yusuke.takara (How can I apply the techniques such as temporal-loss weighting and time marching), it was asked how such a time marching schedule can be implemented.

Moderator @ngeneva suggested implementing a custom Node.

However, I want to pass a custom parametrisation and/or criterion as a function of the training step, e.g.:

```python
time_marching_max_steps = min(cfg.training.max_steps, cfg_data.time_marching_max_steps)
_time_marching_parametrization = lambda step: {
    t_symbol: (
        min(_t_i + step / time_marching_max_steps * _t_f, _t_f)
    )
}
```

or a more complicated one (for space):

```python
space_marching_max_steps = min(cfg.training.max_steps, cfg_data.space_marching_max_steps)
_spat_xy_max = Q_bounds[Q_keys[0]][1]
# radius grows linearly from 2*_r to _spat_xy_max over space_marching_max_steps
_moving_radius = lambda step: (
    2 * _r + (_spat_xy_max - 2 * _r) * step / space_marching_max_steps
)
_elliptic_z_coeff = (_spat_xy_max / _h_tot) ** 2
z_center = _h_tot / 2
# keep only points inside the (growing) ellipsoid centred at (x_center, y_center, z_center)
_space_marching_criteria = lambda step: And(
    (x - x_center) ** 2 + (y - y_center) ** 2 + _elliptic_z_coeff * (z - z_center) ** 2
    <= _moving_radius(step) ** 2
)
```
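Without the Modulus/SymPy objects, the two schedules above reduce to a few lines of plain Python, which makes them easy to check numerically on their own (for example, that the time window and the radius behave as expected at step 0 and after the last marching step). The following is a minimal sketch. The function names are hypothetical, and two choices are assumptions not present in the post: `time_window` returns a `(t_i, t_max)` range, and `moving_radius` is clipped once `step` exceeds `max_steps`.

```python
from typing import Tuple

import numpy as np


def time_window(step: int, t_i: float, t_f: float, max_steps: int) -> Tuple[float, float]:
    """(t_min, t_max) to sample at a given step, using the post's formula.

    Note: the upper bound is t_i + step / max_steps * t_f, so when t_i > 0 it
    reaches t_f before max_steps; (t_f - t_i) would make it span exactly max_steps.
    """
    t_max = min(t_i + step / max_steps * t_f, t_f)
    return (t_i, t_max)


def moving_radius(step: int, r: float, xy_max: float, max_steps: int) -> float:
    """Radius growing linearly from 2*r to xy_max (assumption: clipped at xy_max)."""
    frac = min(step / max_steps, 1.0)
    return 2 * r + (xy_max - 2 * r) * frac


def inside_ellipsoid(x, y, z, center, radius, z_coeff):
    """Boolean mask of points inside the ellipsoid used by the space criterion."""
    xc, yc, zc = center
    return (x - xc) ** 2 + (y - yc) ** 2 + z_coeff * (z - zc) ** 2 <= radius ** 2


# Usage: inspect the schedule at a few steps before wiring it into training.
for s in (0, 50, 100, 200):
    print(s, time_window(s, 0.0, 1.0, 100), moving_radius(s, 0.1, 1.0, 100))
```

In the actual training setup the criterion would remain the SymPy expression above; the NumPy version is only meant for sanity-checking the schedule.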

I have implemented a custom PointwiseConstraint, extending the Constraint class, that updates the criteria/parametrization at every step before computing the losses, e.g.:

```python
    ## here we have the biggest change #####################
    def loss(self, step: int) -> Dict[str, torch.Tensor]:
        """Loss function. It is here that the whole change enters.

        We need to update the dataset before calling the _loss method.
        """
        if self._output_vars is None:
            logger.warn("Calling loss without forward call")
            return {}
        # update the dataset for the current step
        self.update_dataset(step=step)  ## <==== HERE ================
        losses = self._loss(...)

        return losses
```
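One hypothesis worth ruling out (it is not confirmed by anything in the post) is that at early steps the small radius `2*_r` leaves few or no points satisfying the criterion, and a mean over an empty batch is NaN in PyTorch (`torch.zeros(0).mean()` returns `nan`). Below is a hypothetical, framework-independent sketch of a step-dependent dataset that guards against this: it keeps a fixed pool of pre-sampled points, filters them with a step-dependent criterion, and raises a clear error instead of returning an empty batch. `MarchingPointPool`, `min_points` and the dict-of-arrays layout are assumptions for illustration, not Modulus API.

```python
from typing import Callable, Dict

import numpy as np

Points = Dict[str, np.ndarray]  # e.g. {"x": (N, 1) array, "y": (N, 1) array, ...}


class MarchingPointPool:
    """Hypothetical helper: draw batches only from the points that satisfy
    a step-dependent criterion, failing loudly if too few remain."""

    def __init__(self, points: Points, criterion: Callable[[Points, int], np.ndarray],
                 min_points: int = 1, seed: int = 0):
        self.points = points
        self.criterion = criterion  # (points, step) -> boolean mask of shape (N,)
        self.min_points = min_points
        self.rng = np.random.default_rng(seed)

    def active_indices(self, step: int) -> np.ndarray:
        mask = np.asarray(self.criterion(self.points, step), dtype=bool)
        idx = np.flatnonzero(mask)
        if idx.size < self.min_points:
            raise ValueError(
                f"step {step}: only {idx.size} points satisfy the criterion "
                f"(need >= {self.min_points})"
            )
        return idx

    def sample(self, step: int, batch_size: int) -> Points:
        # sample with replacement so a small active region still fills the batch
        idx = self.active_indices(step)
        chosen = self.rng.choice(idx, size=batch_size, replace=True)
        return {name: values[chosen] for name, values in self.points.items()}
```

Sampling with replacement is one design choice among several; resampling new points inside the current region each step would avoid repeating points, at the cost of more geometry sampling.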

It seems to work, but I immediately ran into the dreaded “Loss went to nans” error.

So my questions are:

  1. Is there a better way to achieve marching sampling during training?
  2. Is there a way to debug the “loss went to nans” error?
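For the second question, a generic PyTorch approach (not specific to Modulus) is to check every input batch and every loss term for empty or non-finite values at the step where they first appear, and to enable autograd anomaly detection, which makes `backward()` raise at the operation that produced a NaN gradient. A minimal sketch follows; `nonfinite_terms` and `check_step` are hypothetical helpers, and calling them from the custom `loss()` override (on its input batch and the returned `losses` dict) is an assumption about where to hook them in.

```python
from typing import Dict, List

import torch


def nonfinite_terms(losses: Dict[str, torch.Tensor]) -> List[str]:
    """Names of the loss terms that contain NaN or Inf."""
    return [name for name, value in losses.items() if not bool(torch.isfinite(value).all())]


def check_step(step: int, inputs: Dict[str, torch.Tensor], losses: Dict[str, torch.Tensor]) -> None:
    """Raise with a descriptive message at the first sign of trouble."""
    for name, tensor in inputs.items():
        if tensor.numel() == 0:
            raise RuntimeError(f"step {step}: input '{name}' is empty")
        if not bool(torch.isfinite(tensor).all()):
            raise RuntimeError(f"step {step}: input '{name}' contains NaN/Inf")
    bad = nonfinite_terms(losses)
    if bad:
        raise RuntimeError(f"step {step}: non-finite loss terms {bad}")


# Debug only: makes backward() raise at the first op whose gradient contains NaN.
# It slows training noticeably, so switch it off once the source is found.
torch.autograd.set_detect_anomaly(True)
```

Knowing which term fails first (e.g. only the constraint whose dataset is being updated) and at which step narrows the search considerably.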

Thanks in advance