Hi,
I would like some insight into which optimizers are compatible with Modulus Sym, and into a CUDA memory issue. I am facing two problems:
-
The loss oscillates considerably and barely decreases.
-
CUDA reports an out-of-memory error even though the GPU still has free memory.
error:
ret = run_job(
[12:24:40] - JIT using the NVFuser TorchScript backend
[12:24:40] - JitManager: {'_enabled': True, '_arch_mode': <JitArchMode.ONLY_ACTIVATION: 1>, '_use_nvfuser': True, '_autograd_nodes': False}
[12:24:40] - GraphManager: {'_func_arch': False, '_debug': False, '_func_arch_allow_partial_hessian': True}
[12:26:04] - Installed PyTorch version 2.0.1+cu117 is not TorchScript supported in Modulus. Version 1.14.0a0+410ce96 is officially supported.
[12:26:04] - attempting to restore from: outputs/double_notch_model
[12:26:04] - Success loading optimizer: outputs/double_notch_model/optim_checkpoint.0.pth
[12:26:04] - Success loading model: outputs/double_notch_model/damage_net.0.pth
[12:26:12] - [step: 3000] record constraint batch time: 7.138e-01s
[12:27:01] - [step: 3000] record inferencers time: 4.913e+01s
[12:27:01] - [step: 3000] saved checkpoint to outputs/double_notch_model
[12:27:01] - [step: 3000] loss: 3.528e+22
[12:27:21] - Attempting cuda graph building, this may take a bit...
[12:28:42] - [step: 3100] loss: 3.505e+22, time/iteration: 1.009e+03 ms
[12:30:21] - [step: 3200] loss: 3.589e+22, time/iteration: 9.955e+02 ms
[12:32:00] - [step: 3300] loss: 3.563e+22, time/iteration: 9.932e+02 ms
[12:33:40] - [step: 3400] loss: 3.525e+22, time/iteration: 9.971e+02 ms
[12:35:21] - [step: 3500] loss: 3.546e+22, time/iteration: 1.004e+03 ms
[12:37:00] - [step: 3600] loss: 3.555e+22, time/iteration: 9.981e+02 ms
[12:38:40] - [step: 3700] loss: 3.548e+22, time/iteration: 9.963e+02 ms
[12:40:20] - [step: 3800] loss: 3.704e+22, time/iteration: 9.963e+02 ms
[12:41:59] - [step: 3900] loss: 3.513e+22, time/iteration: 9.949e+02 ms
/usr/local/lib/python3.8/site-packages/modulus/sym/eq/derivatives.py:99: UserWarning: FALLBACK path has been taken inside: runCudaFusionGroup. This is an indication that codegen Failed for some reason.
To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
(Triggered internally at ../third_party/nvfuser/csrc/manager.cpp:335.)
grad = gradient(var, grad_var)
Error executing job with overrides: []
Traceback (most recent call last):
File "/content/drive/MyDrive/luh/DM1/double_notch_model.py", line 332, in run
slv.solve()
File "/usr/local/lib/python3.8/site-packages/modulus/sym/solver/solver.py", line 173, in solve
self._train_loop(sigterm_handler)
File "/usr/local/lib/python3.8/site-packages/modulus/sym/trainer.py", line 607, in _train_loop
self._record_constraints()
File "/usr/local/lib/python3.8/site-packages/modulus/sym/trainer.py", line 289, in _record_constraints
self.record_constraints()
File "/usr/local/lib/python3.8/site-packages/modulus/sym/solver/solver.py", line 130, in record_constraints
self.domain.rec_constraints(self.network_dir)
File "/usr/local/lib/python3.8/site-packages/modulus/sym/domain/domain.py", line 59, in rec_constraints
constraint.save_batch(constraint_data_dir + key)
File "/usr/local/lib/python3.8/site-packages/modulus/sym/domain/constraint/continuous.py", line 74, in save_batch
pred_outvar = modl(invar)
File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/site-packages/modulus/sym/graph.py", line 234, in forward
outvar.update(e(outvar))
File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/site-packages/modulus/sym/eq/derivatives.py", line 99, in forward
grad = gradient(var, grad_var)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "/usr/local/lib/python3.8/site-packages/modulus/sym/eq/derivatives.py", line 38, in gradient
"""
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y, device=y.device)]
grad = torch.autograd.grad(
~~~~~~~~~~~~~~~~~~~ <--- HERE
[
y,
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "<string>", line 155, in fallback_cuda_fuser
def neg(self):
def backward(grad_output):
return grad_output.neg()
~~~~~~~~~~~~~~~ <--- HERE
return torch.neg(self), backward
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 14.75 GiB total capacity; 4.10 GiB already allocated; 2.81 MiB free; 13.76 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Scripts:
- config.yaml
# Hydra configuration for the double-notch damage model.
# Nesting restored: the pasted version had lost all indentation.
defaults:
  - modulus_default
  - arch:
      - fully_connected
  - scheduler: tf_exponential_lr
  - optimizer: radam
  - loss: sum
  - _self_

# Network read by the script through cfg.arch.fully_connected.
arch:
  fully_connected:
    layer_size: 512
    nr_layers: 8

scheduler:
  decay_rate: 0.95
  decay_steps: 15000

training:
  rec_validation_freq: 1000
  rec_results_freq: 1000
  rec_inference_freq: 1000
  rec_constraint_freq: 1000
  max_steps: 100000

# Points per batch for each constraint, read as cfg.batch_size.<name>.
batch_size:
  panel_left: 1100
  panel_right: 1100
  cut: 1000
  panel_bottom: 400
  panel_top: 400
  interior: 10000
- modulus running script
# libraries
from sympy import Symbol, Function, exp, Number
from modulus.sym.eq.pde import PDE
from modulus.sym.node import Node
from sympy import Symbol, Eq
import numpy as np
import modulus.sym
from modulus.sym.hydra import instantiate_arch, ModulusConfig
from modulus.sym.models.layers import Activation
from modulus.sym.solver import Solver
from modulus.sym.domain import Domain
from modulus.sym.geometry.primitives_2d import Rectangle, Circle
from modulus.sym.domain.constraint import (
PointwiseBoundaryConstraint,
PointwiseInteriorConstraint,
)
from modulus.sym.domain.validator import PointwiseValidator
from modulus.sym.domain.inferencer import PointwiseInferencer
from modulus.sym.key import Key
from modulus.sym.node import Node
from modulus.sym.geometry.adf import ADF
# Material constants handed to damage_model(E=..., nu=...) in run().
# NOTE(review): E enters the stress and equilibrium residuals unscaled, so
# residual magnitudes scale with ~1e11; this is presumably related to the
# ~1e22 loss values in the log -- consider non-dimensionalising. TODO confirm.
E = 200e9  # Young's modulus; presumably in Pa -- TODO confirm units
nu = 0.33  # Poisson's ratio (used to build the Lame parameters)
class damage_model(PDE):
    """Plane-stress linear elasticity coupled with a damage evolution law.

    The unknown fields are the displacements ``u``, ``v`` and the damage
    variable ``d``, all functions of ``(x, y, t)``. The damage function is
    ``f = 1 - exp(-d)``.

    The following entries are registered in ``self.equations``:

    * ``LE_sigma_xx``, ``LE_sigma_yy``, ``LE_sigma_xy`` -- stress components.
    * ``eq_x``, ``eq_y`` -- momentum balance residuals.
    * ``tr_x``, ``tr_y`` -- tractions built from ``normal_x`` / ``normal_y``.
    * ``phi_0`` -- elastic energy density.
    * ``f`` -- damage function.
    * ``phi`` -- ``f * phi_0`` plus a gradient term weighted by ``beta``.
    * ``diss`` -- dissipation expression in ``df/dt``.
    * ``p``, ``p_hat``, ``p_plus`` -- driving force, its value shifted by
      ``r1``, and the positive part of the shifted value.
    * ``DM`` -- damage evolution residual.

    Parameters
    ----------
    E, nu :
        Elastic constants. When ``None`` the constant is replaced by a sympy
        function of ``(x, y, t)`` so it can be supplied by another node.
    rho :
        Coefficient of the second time derivative in ``eq_x`` / ``eq_y``.
    beta :
        Coefficient of the gradient and Laplacian terms of ``f``.
    r1 :
        Threshold subtracted from ``p``; also the linear dissipation weight.
    r2 :
        Coefficient of ``df/dt`` in ``diss`` and ``DM``.
    dim, time :
        Stored on the instance only; the equations are always written for
        two space dimensions plus time.
    """

    name = "damage_model"

    def __init__(self, E, nu, rho=1e4, beta=1/8100, r1=1e4, r2=1e4, dim=2, time=True):
        # The original stored ``dim`` only under the misspelled name ``sim``.
        # ``dim`` is now set as well; ``sim`` is kept so existing readers of
        # that attribute keep working.
        self.dim = dim
        self.sim = dim
        self.time = time

        # Independent variables and boundary-normal components.
        x, y, t = Symbol("x"), Symbol("y"), Symbol("t")
        normal_x, normal_y = Symbol("normal_x"), Symbol("normal_y")
        input_variables = {"x": x, "y": y, "t": t}

        # Material properties fall back to fields of (x, y, t) when not given.
        if E is None:
            E = Function("E", positive=True)(*input_variables)
        if nu is None:
            nu = Function("nu", positive=True)(*input_variables)

        # Lame parameters.
        lambda_ = nu * E / ((1 + nu) * (1 - 2 * nu))
        mu = E / (2 * (1 + nu))

        # Unknown fields.
        u = Function("u")(*input_variables)
        v = Function("v")(*input_variables)
        d = Function("d", nonnegative=True)(*input_variables)

        # Damage function and its time rate.
        f = 1 - exp(-d)
        f_t = f.diff(t)

        # Strains.
        e_xx = u.diff(x)
        e_yy = v.diff(y)
        e_xy = 0.5 * (v.diff(x) + u.diff(y))

        # Stresses; w_z is the out-of-plane strain term of the plane-stress form.
        w_z = -lambda_ / (lambda_ + 2 * mu) * (e_xx + e_yy)
        sigma_xx = lambda_ * (e_xx + e_yy + w_z) + 2 * mu * e_xx
        sigma_yy = lambda_ * (e_xx + e_yy + w_z) + 2 * mu * e_yy
        sigma_xy = mu * (u.diff(y) + v.diff(x))

        # Energy and driving-force expressions.
        phi_0 = 0.5 * (sigma_xx * e_xx + sigma_yy * e_yy + 2 * sigma_xy * e_xy)
        p = -1 * phi_0 + beta * (f.diff(x, 2) + f.diff(y, 2))
        p_hat = p - r1
        # (z**2)**0.5 stands in for |z|, so p_plus is the positive part of p_hat.
        # NOTE(review): the derivative of this form is undefined at z == 0,
        # which may produce NaN gradients -- TODO confirm / regularise.
        p_plus = 0.5 * (p_hat + (p_hat ** 2) ** 0.5)

        # Registration order is kept identical to the original definition.
        self.equations = {}
        self.equations["LE_sigma_xx"] = sigma_xx
        self.equations["LE_sigma_yy"] = sigma_yy
        self.equations["LE_sigma_xy"] = sigma_xy
        self.equations["eq_x"] = rho * ((u.diff(t)).diff(t)) - (sigma_xx.diff(x) + sigma_xy.diff(y))
        self.equations["eq_y"] = rho * ((v.diff(t)).diff(t)) - (sigma_xy.diff(x) + sigma_yy.diff(y))
        self.equations["tr_x"] = normal_x * sigma_xx + normal_y * sigma_xy
        self.equations["tr_y"] = normal_x * sigma_xy + normal_y * sigma_yy
        self.equations["phi_0"] = phi_0
        self.equations["f"] = f
        self.equations["phi"] = f * phi_0 + 0.5 * beta * ((f.diff(x)) ** 2 + (f.diff(y)) ** 2)
        self.equations["diss"] = r1 * ((f_t ** 2) ** 0.5) + 0.5 * r2 * (f_t ** 2)
        self.equations["p"] = p
        self.equations["p_hat"] = p_hat
        self.equations["p_plus"] = p_plus
        # NOTE(review): with p_plus >= 0 this residual enforces
        # df/dt = -p_plus / r2 <= 0; confirm the sign against the intended
        # evolution law.
        self.equations["DM"] = r2 * f_t + p_plus
# running the model
@modulus.sym.main(config_path="config", config_name="conf")
def run(cfg: ModulusConfig) -> None:
    """Build and train the double-notch damage model.

    Steps: create the PDE nodes and the fully connected network
    (x, y, t) -> (u, v, d), build the notched panel geometry, register the
    boundary, initial and interior constraints plus ten time-slice
    inferencers, then run the solver.

    Changes against the pasted version:

    * The notch boundary now uses ``cfg.batch_size.cut``, which the config
      defines, instead of ``cfg.batch_size.panel_right``.
    * The initial condition is sampled at ``t = bound_t[0]`` only. It was
      sampled over the whole time range, which demanded u = v = d = 0 for
      all t and contradicted the moving top boundary ``v = 0.01 * t``.

    NOTE(review): the interior batch is evaluated with second derivatives
    of a network sized by ``cfg.arch.fully_connected``; if GPU memory runs
    out, lowering ``cfg.batch_size.interior`` is presumably the first knob
    to try -- TODO confirm.

    Parameters
    ----------
    cfg :
        Hydra/Modulus configuration; ``cfg.arch.fully_connected`` and
        ``cfg.batch_size.*`` are read here.
    """
    equations = damage_model(E=E, nu=nu)

    damage_net = instantiate_arch(
        input_keys=[Key("x"), Key("y"), Key("t")],
        output_keys=[Key("u"), Key("v"), Key("d")],
        cfg=cfg.arch.fully_connected,
        activation_fn=Activation.TANH,
    )
    nodes = equations.make_nodes() + [damage_net.make_node(name="damage_net")]

    # Geometry: rectangular panel with a circular notch centred on each of
    # the left and right edges.
    panel_left_bottom = (-.02, -.055)
    panel_right_top = (.02, .055)
    circle_radius = 0.005  # (m)
    circle1_origin = (-.02, -.01)  # (m, m)
    circle2_origin = (.02, .01)  # (m, m)
    c1 = Circle(circle1_origin, circle_radius)
    c2 = Circle(circle2_origin, circle_radius)
    panel = Rectangle(panel_left_bottom, panel_right_top)
    cut = c1 + c2
    geo = panel - cut

    x = Symbol("x")
    y = Symbol("y")
    t_symbol = Symbol("t")
    bound_x = (panel_left_bottom[0], panel_right_top[0])
    bound_y = (panel_left_bottom[1], panel_right_top[1])
    bound_t = (0.0, 1.0)
    time_range = {t_symbol: bound_t}

    domain = Domain()

    # Traction-free left edge.
    left = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"tr_x": 0.0, "tr_y": 0.0},
        batch_size=cfg.batch_size.panel_left,
        criteria=Eq(x, panel_left_bottom[0]),
        parameterization=time_range,
        lambda_weighting={"tr_x": 1.0, "tr_y": 1.0},
        quasirandom=True,
        batch_per_epoch=500,
    )
    domain.add_constraint(left, "BC_left")

    # Traction-free right edge.
    right = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"tr_x": 0.0, "tr_y": 0.0},
        batch_size=cfg.batch_size.panel_right,
        criteria=Eq(x, panel_right_top[0]),
        parameterization=time_range,
        lambda_weighting={"tr_x": 1.0, "tr_y": 1.0},
        quasirandom=True,
        batch_per_epoch=500,
    )
    domain.add_constraint(right, "BC_right")

    # Traction-free notch boundaries.
    # NOTE(review): this samples the full boundary of both circles, whose
    # centres lie on the panel edges, so part of the samples presumably fall
    # outside the panel -- consider restricting with a criteria. TODO confirm.
    cutting = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=cut,
        outvar={"tr_x": 0.0, "tr_y": 0.0},
        batch_size=cfg.batch_size.cut,  # was cfg.batch_size.panel_right
        parameterization=time_range,
        lambda_weighting={"tr_x": 10.0, "tr_y": 10.0},
        quasirandom=True,
        batch_per_epoch=500,
    )
    domain.add_constraint(cutting, "BC_cutting")

    # Bottom edge: zero vertical displacement and zero damage.
    panel_bottom = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"v": 0.0, "d": 0.0},
        batch_size=cfg.batch_size.panel_bottom,
        criteria=Eq(y, panel_left_bottom[1]),
        parameterization=time_range,
        lambda_weighting={"v": 10.0, "d": 10.0},
        quasirandom=True,
        batch_per_epoch=500,
    )
    domain.add_constraint(panel_bottom, "BC_bottom")

    # Top edge: prescribed vertical displacement growing linearly in time.
    panel_top = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"v": 0.01 * t_symbol},
        batch_size=cfg.batch_size.panel_top,
        criteria=Eq(y, panel_right_top[1]),
        parameterization=time_range,
        lambda_weighting={"v": 10.0},
        quasirandom=True,
        batch_per_epoch=500,
    )
    domain.add_constraint(panel_top, "BC_top")

    # Initial condition: all fields are zero at the initial time only.
    IC = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"u": 0.0, "v": 0.0, "d": 0.0},
        batch_size=cfg.batch_size.interior,
        lambda_weighting={"u": 1.0, "v": 1.0, "d": 1.0},
        quasirandom=True,
        parameterization={t_symbol: bound_t[0]},  # was the full time range
        batch_per_epoch=500,
    )
    domain.add_constraint(IC, "IC")

    # PDE residuals in the interior over the whole time range.
    interior = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"eq_x": 0.0, "eq_y": 0.0, "DM": 0.0},
        batch_size=cfg.batch_size.interior,
        bounds={x: bound_x, y: bound_y},
        lambda_weighting={"eq_x": 10.0, "eq_y": 10.0, "DM": 10.0},
        quasirandom=True,
        parameterization=time_range,
    )
    domain.add_constraint(interior, "Interior")

    # Ten inferencers at equally spaced times, each on 1,000,000 interior
    # points evaluated in batches of 4096.
    for i, specific_time in enumerate(np.linspace(0, bound_t[1], 10)):
        invar_numpy = geo.sample_interior(
            1000000,
            bounds={x: bound_x, y: bound_y},
            parameterization={t_symbol: specific_time},
        )
        grid_inference = PointwiseInferencer(
            nodes=nodes,
            invar=invar_numpy,
            output_names=["u", "v", "d"],
            batch_size=4096,
        )
        domain.add_inferencer(grid_inference, name=f"time_slice_{i + 1:03d}")

    # Build the solver and start training.
    slv = Solver(cfg, domain)
    slv.solve()
# Start training only when the file is executed as a script.
if __name__ == "__main__":
    run()
Could anyone please explain these issues and how I can correct them? Please let me know if you also need the complete output of a successful training run.
Thank you