Learning different initial conditions

Hi, new Modulus user here.

I have been trying to extend the 1D wave equation example in the documentation so that it also learns different initial amplitudes, say the range (1, 5) for the initial amplitude. The example as stated just uses an amplitude of 1. I want this because I eventually want to propagate 2D waves with different initial conditions. I have tried to do this by adding the amplitude as an additional input variable; see below.

@modulus.sym.main(config_path="conf", config_name="config")
def run(cfg: ModulusConfig) -> None:
    # First attempt (excerpt only): the amplitude "a" is added as a third
    # network input and used symbolically in the initial condition.
    # make list of nodes to unroll graph on
    we = WaveEquationDifferentAmplitude1D(c=1.0)
    wave_net = instantiate_arch(
        # the network takes position, time and amplitude as inputs
        input_keys=[Key("x"), Key("t"), Key("a")],
        output_keys=[Key("u")],
        cfg=cfg.arch.fully_connected,
        verbose = True
    )
    nodes = we.make_nodes() + [wave_net.make_node(name="wave_network")]

    # add constraints to solver
    # make geometry
    x, t_symbol = Symbol("x"), Symbol("t")
    a = Symbol("a")
    L = float(np.pi)
    geo = Line1D(0, L)
    # not used in the lines shown here; presumably consumed by the
    # constraints omitted from this excerpt
    time_range = {t_symbol: (0, 2 * L)}

    # make domain
    domain = Domain()

    # initial condition
    # NOTE(review): `a` appears in `outvar`, but the parameterization below
    # only fixes t = 0 and nothing visible here supplies values for `a`.
    # Presumably this is what triggers the numpy.object_ TypeError quoted
    # after this snippet -- confirm.
    IC = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"u": a*sin(x), "u__t": a*sin(x)},
        batch_size=cfg.batch_size.IC,
        lambda_weighting={"u": 1.0, "u__t": 1.0},
        parameterization={t_symbol: 0.0},
    )
    domain.add_constraint(IC, "IC")

The rest is the same. But I am getting the following error:

TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

It appears to be associated with the data type of the output variable. I have also tried changing the wave equation class.

I’m assuming there is a much better way to make this simple change?

(Note: I have read the documentation on parameterization, but I don’t quite see how it applies here.)

I think I solved it by parameterising the amplitude; see below. I also included a way to plot the results.

import numpy as np
import matplotlib.pyplot as plt
from sympy import Symbol, sin

import modulus.sym
from modulus.sym.hydra import instantiate_arch, ModulusConfig
from modulus.sym.solver import Solver
from modulus.sym.domain import Domain
from modulus.sym.geometry import Parameterization
from modulus.sym.geometry.primitives_1d import Line1D
from modulus.sym.domain.constraint import (
    PointwiseBoundaryConstraint,
    PointwiseInteriorConstraint,
)

from modulus.sym.domain.validator import PointwiseValidator
from modulus.sym.domain.inferencer import PointwiseInferencer
from modulus.sym.key import Key
from modulus.sym.node import Node
from wave_equation import WaveEquation1D
from modulus.sym.utils.io import (ValidatorPlotter, InferencerPlotter)

class ValidatorPlotter3D(ValidatorPlotter):
    """Validator plotter that also handles three input variables.

    Data with two inputs is delegated to the base ``ValidatorPlotter``.
    Data with three inputs is plotted as a few 2D slices taken at fixed
    values of the third input (third in ``invar`` key order); each slice
    gets true / pred / diff panels over the first two inputs.
    """

    def __call__(self, invar, true_outvar, pred_outvar):
        """Plot validator data with up to three input variables.

        Args:
            invar: mapping of input name to array. Arrays are indexed as
                ``[:, 0]`` here, so they are assumed to be 2D with one
                column per variable -- TODO confirm against the caller.
            true_outvar: mapping of output name to reference values, with
                rows aligned with ``invar``.
            pred_outvar: mapping of output name to predicted values, with
                rows aligned with ``invar``.

        Returns:
            A list of ``(figure, name)`` tuples. The list is empty when
            there are more than three inputs, fewer than two inputs, or no
            sample points. For exactly two inputs the base class result is
            returned unchanged.
        """
        keys = list(invar.keys())
        ndim = len(keys)
        if ndim > 3:
            print("3D plotter can only handle <=3 input dimensions, passing")
            return []

        # For 2D data, use the base class method
        if ndim == 2:
            return super().__call__(invar, true_outvar, pred_outvar)

        # Anything that is not 3D at this point (0 or 1 inputs) is not plotted
        if ndim != 3:
            return []

        x_key, y_key, slice_key = keys
        slice_column = invar[slice_key][:, 0]
        unique_vals = np.unique(slice_column)
        if unique_vals.size == 0:
            # No sample points: nothing to slice or plot
            return []

        # Pick up to `num_slices` evenly spaced values of the third input.
        # np.unique on the indices avoids plotting the same slice twice
        # (with the same figure name) when fewer distinct values exist.
        num_slices = 3
        idx = np.unique(
            np.round(np.linspace(0, len(unique_vals) - 1, num_slices)).astype(int)
        )

        fs = []
        for slice_val in unique_vals[idx]:
            # Select the rows belonging to this slice
            mask = slice_column == slice_val
            # Keep only the two in-plane inputs. The comparison is by key
            # equality; a substring test would wrongly drop e.g. "a" when
            # the slice key is "amp".
            plane_invar = {
                k: v[mask][:, :2] for k, v in invar.items() if k != slice_key
            }
            true_slice = {k: v[mask] for k, v in true_outvar.items()}
            pred_slice = {k: v[mask] for k, v in pred_outvar.items()}

            # Use _interpolate_2D to interpolate 2D data for this slice
            extent, true_interp, pred_interp = self._interpolate_2D(
                100, plane_invar, true_slice, pred_slice
            )

            # One figure per output variable, with true / pred / diff panels
            for k in pred_interp:
                f = plt.figure(figsize=(3 * 5, 4), dpi=100)
                panels = zip(
                    [true_interp[k], pred_interp[k], true_interp[k] - pred_interp[k]],
                    ["true", "pred", "diff"],
                )
                for i, (o, tag) in enumerate(panels):
                    plt.subplot(1, 3, 1 + i)
                    plt.imshow(o.T, origin="lower", extent=extent)
                    plt.xlabel(x_key)
                    plt.ylabel(y_key)
                    plt.colorbar()
                    plt.title(f"{k}_{tag} (slice at {slice_key}={slice_val:.2f})")
                plt.tight_layout()
                fs.append((f, f"{k}_slice_at_{slice_key}_{slice_val:.2f}"))

        return fs


@modulus.sym.main(config_path="conf", config_name="config")
def run(cfg: ModulusConfig) -> None:
    """Train a network for the 1D wave equation over a range of amplitudes.

    The amplitude ``a`` is a network input alongside ``x`` and ``t`` and is
    sampled through a ``Parameterization``, so the initial condition
    ``a * sin(x)`` receives numeric values for ``a``.

    Args:
        cfg: Modulus config; this function reads ``cfg.arch.fully_connected``
            and ``cfg.batch_size.{IC, BC, interior}``, and passes ``cfg`` on
            to the ``Solver``.
    """
    # make list of nodes to unroll graph on
    we = WaveEquation1D(c=1.0)
    wave_net = instantiate_arch(
        input_keys=[Key("x"), Key("t"), Key("a")],
        output_keys=[Key("u")],
        cfg=cfg.arch.fully_connected,
    )
    nodes = we.make_nodes() + [wave_net.make_node(name="wave_network")]

    # make geometry and sampling ranges
    # (symbols get a `_symbol` suffix so they cannot be confused with the
    # numpy validation grids built further down)
    x_symbol, t_symbol, a_symbol = Symbol("x"), Symbol("t"), Symbol("a")
    L = float(np.pi)
    geo = Line1D(0, L)
    time_range = (0, 2 * L)
    amp_range = (1, 3)

    # time and amplitude are both sampled for the BC and interior constraints
    param = Parameterization({t_symbol: time_range, a_symbol: amp_range})
    # for the initial condition time is pinned to 0; amplitude is still sampled
    param_initial = Parameterization({t_symbol: 0.0, a_symbol: amp_range})

    # make domain
    domain = Domain()

    # initial condition: u(x, 0) = a sin(x) and u_t(x, 0) = a sin(x)
    IC = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={
            "u": a_symbol * sin(x_symbol),
            "u__t": a_symbol * sin(x_symbol),
        },
        batch_size=cfg.batch_size.IC,
        lambda_weighting={"u": 1.0, "u__t": 1.0},
        parameterization=param_initial,
    )
    domain.add_constraint(IC, "IC")

    # boundary condition: u = 0 at both ends of the line
    BC = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"u": 0},
        batch_size=cfg.batch_size.BC,
        parameterization=param,
    )
    domain.add_constraint(BC, "BC")

    # interior: drive the residual named "wave_equation" to zero
    interior = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"wave_equation": 0},
        batch_size=cfg.batch_size.interior,
        parameterization=param,
    )
    domain.add_constraint(interior, "interior")

    # validation data on a regular (x, t, a) grid spanning the same ranges
    # used for training (np.arange excludes the upper bound of each range)
    delta_x = 0.01
    delta_t = 0.01
    delta_a = 0.5
    x_grid = np.arange(0, L, delta_x)
    t_grid = np.arange(time_range[0], time_range[1], delta_t)
    a_grid = np.arange(amp_range[0], amp_range[1], delta_a)
    # flatten each mesh into an (N, 1) column
    X, T, A = (
        np.expand_dims(mesh.flatten(), axis=-1)
        for mesh in np.meshgrid(x_grid, t_grid, a_grid)
    )
    # Reference solution matching the IC/BC above for wave speed 1:
    # u = a sin(x) (cos(t) + sin(t)).
    # NOTE(review): assumes WaveEquation1D encodes u_tt = c^2 u_xx -- confirm.
    u = A * (np.sin(X) * (np.cos(T) + np.sin(T)))
    invar_numpy = {"x": X, "t": T, "a": A}
    outvar_numpy = {"u": u}

    validator = PointwiseValidator(
        nodes=nodes,
        invar=invar_numpy,
        true_outvar=outvar_numpy,
        batch_size=1024,
        plotter=ValidatorPlotter3D(),
    )
    domain.add_validator(validator)

    # make solver
    slv = Solver(cfg, domain)

    # start solver
    slv.solve()


# Script entry point. run() is called without arguments even though it
# declares `cfg`; presumably the @modulus.sym.main decorator supplies it.
if __name__ == "__main__":
    run()