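I think I solved it by parameterising the amplitude: the network now takes the amplitude `a` as an extra input alongside `x` and `t` (see below). I also included a way to plot the results as 2D slices at a few fixed amplitudes.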
import numpy as np
import matplotlib.pyplot as plt
from sympy import Symbol, sin
import modulus.sym
from modulus.sym.hydra import instantiate_arch, ModulusConfig
from modulus.sym.solver import Solver
from modulus.sym.domain import Domain
from modulus.sym.geometry import Parameterization
from modulus.sym.geometry.primitives_1d import Line1D
from modulus.sym.domain.constraint import (
PointwiseBoundaryConstraint,
PointwiseInteriorConstraint,
)
from modulus.sym.domain.validator import PointwiseValidator
from modulus.sym.domain.inferencer import PointwiseInferencer
from modulus.sym.key import Key
from modulus.sym.node import Node
from wave_equation import WaveEquation1D
from modulus.sym.utils.io import (ValidatorPlotter, InferencerPlotter)
class ValidatorPlotter3D(ValidatorPlotter):
    """Validator plotter that can also handle three input variables.

    With one or two inputs the call is forwarded to the base
    ``ValidatorPlotter``. With three inputs, the third input key (in ``invar``
    order) is treated as a slicing parameter: up to ``num_slices`` evenly
    spaced unique values of it are selected, and for each one the first two
    inputs are interpolated onto a 2D grid and drawn as true / pred / diff
    images. More than three inputs are skipped.
    """

    def __init__(self, num_slices: int = 3):
        """
        Args:
            num_slices: maximum number of slices of the third input to plot
                (fewer are drawn if it has fewer unique values).
        """
        super().__init__()
        self.num_slices = num_slices

    def __call__(self, invar, true_outvar, pred_outvar):
        """Plot validator results.

        Args:
            invar: mapping of input name -> array; columns are read as
                ``[:, 0]``, so presumably shape (N, 1) -- TODO confirm.
            true_outvar: mapping of output name -> reference values, same
                row count as the inputs.
            pred_outvar: mapping of output name -> predicted values.

        Returns:
            List of ``(figure, name)`` tuples; empty if nothing was plotted.
        """
        ndim = len(invar)
        if ndim > 3:
            print("3D plotter can only handle <=3 input dimensions, passing")
            return []
        if ndim == 3:
            return self._plot_slices(invar, true_outvar, pred_outvar)
        if ndim == 0:
            return []
        # Defer 1D/2D inputs to the base ValidatorPlotter -- assumes it
        # supports both; confirm against the installed Modulus version.
        return super().__call__(invar, true_outvar, pred_outvar)

    def _plot_slices(self, invar, true_outvar, pred_outvar):
        """Plot 2D true/pred/diff images at selected values of the third input."""
        x_key, y_key, slice_key = list(invar.keys())
        slice_col = invar[slice_key][:, 0]

        unique_vals = np.unique(slice_col)  # sorted unique values
        if unique_vals.size == 0:
            return []
        # Evenly spaced picks across the sorted unique values. The outer
        # np.unique drops repeated indices, which occur when there are fewer
        # unique values than requested slices.
        idx = np.unique(
            np.round(
                np.linspace(0, unique_vals.size - 1, self.num_slices)
            ).astype(int)
        )

        fs = []
        for slice_val in unique_vals[idx]:
            mask = slice_col == slice_val
            # Exact key match: only the two plotted axes go to the 2D
            # interpolator, which expects exactly two input keys.
            invar_2d = {k: invar[k][mask] for k in (x_key, y_key)}
            true_slice = {k: v[mask] for k, v in true_outvar.items()}
            pred_slice = {k: v[mask] for k, v in pred_outvar.items()}
            extent, true_interp, pred_interp = self._interpolate_2D(
                100, invar_2d, true_slice, pred_slice
            )

            for k in pred_interp:
                f = plt.figure(figsize=(3 * 5, 4), dpi=100)
                panels = (
                    (true_interp[k], "true"),
                    (pred_interp[k], "pred"),
                    (true_interp[k] - pred_interp[k], "diff"),
                )
                for i, (o, tag) in enumerate(panels):
                    plt.subplot(1, 3, 1 + i)
                    plt.imshow(o.T, origin="lower", extent=extent)
                    plt.xlabel(x_key)
                    plt.ylabel(y_key)
                    plt.colorbar()
                    plt.title(f"{k}_{tag} (slice at {slice_key}={slice_val:.2f})")
                plt.tight_layout()
                fs.append((f, f"{k}_slice_at_{slice_key}_{slice_val:.2f}"))
        return fs
def _analytic_validation_data(length, time_range, amp_range, dx=0.01, dt=0.01, da=0.5):
    """Build validation inputs/outputs from the analytic solution.

    The reference solution is u(x, t; a) = a * sin(x) * (cos(t) + sin(t)),
    evaluated on a regular (x, t, a) grid. x covers [0, length) and t covers
    [time_range[0], time_range[1]) with the given steps. The amplitude grid
    includes BOTH ends of ``amp_range``, so the whole trained amplitude range
    is validated.

    Returns:
        (invar, outvar): dicts of (N, 1) arrays keyed "x", "t", "a" and "u".
    """
    x_grid = np.arange(0, length, dx)
    t_grid = np.arange(time_range[0], time_range[1], dt)
    # Half-step tolerance so the upper amplitude bound is included despite
    # floating-point rounding in arange.
    a_grid = np.arange(amp_range[0], amp_range[1] + da / 2, da)
    X, T, A = (g.reshape(-1, 1) for g in np.meshgrid(x_grid, t_grid, a_grid))
    u = A * (np.sin(X) * (np.cos(T) + np.sin(T)))
    return {"x": X, "t": T, "a": A}, {"u": u}


@modulus.sym.main(config_path="conf", config_name="config")
def run(cfg: ModulusConfig) -> None:
    """Train a network for the 1D wave equation with amplitude as an input.

    The network maps (x, t, a) -> u, so one model covers the initial
    condition u = a*sin(x), u_t = a*sin(x) at t = 0 for every amplitude a in
    ``amp_range``. ``cfg`` is supplied by the ``@modulus.sym.main`` decorator
    (config_path="conf", config_name="config").
    """
    # Nodes: PDE residual from the project-local WaveEquation1D (c=1.0,
    # presumably the wave speed) plus the (x, t, a) -> u network.
    we = WaveEquation1D(c=1.0)
    wave_net = instantiate_arch(
        input_keys=[Key("x"), Key("t"), Key("a")],
        output_keys=[Key("u")],
        cfg=cfg.arch.fully_connected,
    )
    nodes = we.make_nodes() + [wave_net.make_node(name="wave_network")]

    # Geometry and parameter ranges.
    x_symbol, t_symbol, a_symbol = Symbol("x"), Symbol("t"), Symbol("a")
    L = float(np.pi)
    geo = Line1D(0, L)
    time_range = (0, 2 * L)
    amp_range = (1, 3)
    param = Parameterization({t_symbol: time_range, a_symbol: amp_range})
    # Initial-condition samples: t pinned to 0, amplitude still sampled.
    param_initial = Parameterization({t_symbol: 0.0, a_symbol: amp_range})

    domain = Domain()

    # Initial condition at t = 0; matches the analytic solution used for
    # validation, whose value and time derivative at t = 0 are both a*sin(x).
    ic = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"u": a_symbol * sin(x_symbol), "u__t": a_symbol * sin(x_symbol)},
        batch_size=cfg.batch_size.IC,
        lambda_weighting={"u": 1.0, "u__t": 1.0},
        parameterization=param_initial,
    )
    domain.add_constraint(ic, "IC")

    # Boundary condition: u = 0 on the boundary of [0, L], over all t and a.
    bc = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"u": 0},
        batch_size=cfg.batch_size.BC,
        parameterization=param,
    )
    domain.add_constraint(bc, "BC")

    # Interior: drive the "wave_equation" residual to zero.
    interior = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"wave_equation": 0},
        batch_size=cfg.batch_size.interior,
        parameterization=param,
    )
    domain.add_constraint(interior, "interior")

    # Validation against the analytic solution on the same ranges used for
    # training; slices along "a" are plotted by ValidatorPlotter3D.
    invar_numpy, outvar_numpy = _analytic_validation_data(L, time_range, amp_range)
    validator = PointwiseValidator(
        nodes=nodes,
        invar=invar_numpy,
        true_outvar=outvar_numpy,
        batch_size=1024,
        plotter=ValidatorPlotter3D(),
    )
    domain.add_validator(validator)

    slv = Solver(cfg, domain)
    slv.solve()


if __name__ == "__main__":
    run()