Issues Saving and Using Fine-Tuned Model for Inference in stokes_mgn (PhysicsNemo)

Hello,

I’m currently working on the stokes_mgn example from the NVIDIA PhysicsNemo CFD repository. I was able to successfully train the model as described in the README and also perform physics-informed fine-tuning.

However, I noticed that after fine-tuning, the example code does not save the fine-tuned model. To address this, I made modifications to save the model using both the standard PyTorch save method as well as the PhysicsNemo/Modulus save_checkpoint utility.

I also wrote a separate inference script to run predictions on different .vtp files using this saved fine-tuned model. However, the predictions I get are not as accurate as expected — they are significantly worse than those seen during or after fine-tuning.

These are the modifications I made to pi_fine_tuning_gnn.py:

After fine-tuning completes, I save the model in two ways:

 logger.info("Physics-informed fine-tuning training completed!")

    # Save the model state
    # Plain PyTorch export: only the network weights (state_dict) are stored.
    # To reuse them, rebuild MeshGraphNet with the same constructor arguments
    # and call load_state_dict on the loaded object.
    model_save_path = os.path.join(to_absolute_path(cfg.results_dir), "pi_fine_tuner.pt")
    torch.save(pi_fine_tuner.model.state_dict(), model_save_path)
    logger.info(f"Model saved at {model_save_path}")    

    # PhysicsNemo checkpoint of model + optimizer + scheduler into
    # cfg.results_dir, tagged with epoch=cfg.pi_iters; restore it with
    # physicsnemo's load_checkpoint pointed at the same directory.
    save_checkpoint(
        to_absolute_path(cfg.results_dir),
        models=pi_fine_tuner.model,
        optimizer=pi_fine_tuner.optimizer,
        scheduler=pi_fine_tuner.scheduler,
        epoch=cfg.pi_iters,
    )

These are the scripts I created for inference:

1. For the model saved with the PyTorch method (`torch.save`):

import os
import numpy as np
import torch
import pyvista as pv
from omegaconf import OmegaConf
from physicsnemo.models.meshgraphnet import MeshGraphNet
from utils import get_dataset


def simple_inference(vtp_file_path, model_path, config_path, output_path=None):
    """
    Run a fine-tuned MeshGraphNet, saved as a plain ``state_dict``, on one VTP file.

    Parameters:
    -----------
    vtp_file_path : str
        VTP file to predict on. It is parsed with ``get_dataset``, which also
        supplies the reference fields and the original GNN predictions used
        for the error report.
    model_path : str
        ``.pt`` file containing the fine-tuned model's ``state_dict``.
    config_path : str
        YAML config providing the MeshGraphNet architecture settings.
    output_path : str, optional
        Destination VTP file; defaults to ``<input without extension>_fine_tuned.vtp``.

    Returns:
    --------
    tuple of numpy.ndarray or None
        ``(pred_u, pred_v, pred_p)`` as single-column arrays, or ``None`` when
        the VTP file cannot be processed.
    """

    def rel_l2_percent(pred, ref):
        # Relative L2 error, expressed in percent.
        return np.linalg.norm(pred - ref) / np.linalg.norm(ref) * 100

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    cfg = OmegaConf.load(config_path)

    # Rebuild the architecture exactly as during fine-tuning, then load weights.
    print("Loading fine-tuned model...")
    model = MeshGraphNet(
        cfg.input_dim_nodes + 128,  # node features carry 128 extra Fourier features
        cfg.input_dim_edges,
        cfg.output_dim,
        aggregation=cfg.aggregation,
        hidden_dim_node_encoder=cfg.hidden_dim_node_encoder,
        hidden_dim_edge_encoder=cfg.hidden_dim_edge_encoder,
        hidden_dim_node_decoder=cfg.hidden_dim_node_decoder,
    ).to(device)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    print("Processing VTP file...")
    try:
        (
            ref_u, ref_v, ref_p,
            gnn_u, gnn_v, gnn_p,
            _coords, _inflow, _outflow, _wall, _polygon,
            nu,
            dgl_graph,
        ) = get_dataset(vtp_file_path, return_graph=True)
        print(f"Successfully processed VTP file with {dgl_graph.number_of_nodes()} nodes")
        print(f"Viscosity (nu): {nu}")
    except Exception as e:
        print(f"Error processing VTP file: {e}")
        return None

    print("Running inference...")
    dgl_graph = dgl_graph.to(device)
    with torch.no_grad():
        output = model(dgl_graph.ndata["x"], dgl_graph.edata["x"], dgl_graph)
        pred_u, pred_v, pred_p = (
            output[:, col : col + 1].detach().cpu().numpy() for col in range(3)
        )

    components = ("U", "V", "P")
    references = (ref_u, ref_v, ref_p)
    print("\nResults:")
    print("Original GNN vs Reference:")
    for label, pred, ref in zip(components, (gnn_u, gnn_v, gnn_p), references):
        print(f"  {label} error: {rel_l2_percent(pred, ref):.3f}%")
    print("\nFine-tuned Model vs Reference:")
    for label, pred, ref in zip(components, (pred_u, pred_v, pred_p), references):
        print(f"  {label} error: {rel_l2_percent(pred, ref):.3f}%")

    if output_path is None:
        output_path = f"{os.path.splitext(vtp_file_path)[0]}_fine_tuned.vtp"

    print(f"\nSaving results to: {output_path}")
    polydata = pv.read(vtp_file_path)
    for field_name, values in zip(
        ("fine_tuned_u", "fine_tuned_v", "fine_tuned_p"), (pred_u, pred_v, pred_p)
    ):
        polydata[field_name] = values.flatten()
    polydata.save(output_path)

    print("Inference completed successfully!")
    return pred_u, pred_v, pred_p


# Example usage: edit the three paths below to match your setup.
if __name__ == "__main__":
    vtp_file, model_file, config_file = (
        "path/to/your/input.vtp",  # input VTP file
        "path/to/your/pi_fine_tuner.pt",  # fine-tuned state_dict
        "path/to/your/config.yaml",  # configuration used for training
    )
    results = simple_inference(vtp_file, model_file, config_file)

2. For the model saved with the checkpoint method (`save_checkpoint`):
import os
import numpy as np
import torch
import pyvista as pv
import dgl
from physicsnemo.models.meshgraphnet import MeshGraphNet
from physicsnemo.launch.utils import load_checkpoint
from hydra.utils import to_absolute_path
from utils import get_dataset


def load_fine_tuned_model(results_dir, device, cfg=None):
    """Rebuild the fine-tuned MeshGraphNet and restore its weights from a checkpoint.

    Parameters
    ----------
    results_dir : str
        Directory that was passed to ``save_checkpoint`` during fine-tuning
        (resolved with ``to_absolute_path``).
    device : torch.device or str
        Device the model is moved to and the checkpoint is mapped onto.
    cfg : mapping, optional
        Architecture overrides. Keys missing from ``cfg`` fall back to the
        hard-coded defaults below, which must match the fine-tuning run.

    Returns
    -------
    MeshGraphNet
        The restored model, in eval mode.

    Raises
    ------
    FileNotFoundError
        If ``results_dir`` does not exist or holds no ``.mdlus`` model file.
        PhysicsNemo's ``load_checkpoint`` only logs a warning and skips loading
        in that case, which would otherwise leave the network with random
        weights and produce meaningless predictions.
    """
    # Defaults copied from config.yaml; they must match the fine-tuning run.
    defaults = {
        "input_dim_nodes": 7,
        "input_dim_edges": 3,
        "output_dim": 3,
        "hidden_dim_node_encoder": 256,
        "hidden_dim_edge_encoder": 256,
        "hidden_dim_node_decoder": 256,
        "aggregation": "sum",
    }
    cfg = {**defaults, **(cfg or {})}

    ckpt_dir = to_absolute_path(results_dir)
    # Fail loudly instead of silently evaluating an untrained network.
    if not os.path.isdir(ckpt_dir):
        raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_dir}")
    if not any(name.endswith(".mdlus") for name in os.listdir(ckpt_dir)):
        raise FileNotFoundError(
            f"No .mdlus model file found in checkpoint directory: {ckpt_dir}"
        )

    model = MeshGraphNet(
        cfg["input_dim_nodes"] + 128,  # +128 due to Fourier features
        cfg["input_dim_edges"],
        cfg["output_dim"],
        aggregation=cfg["aggregation"],
        hidden_dim_node_encoder=cfg["hidden_dim_node_encoder"],
        hidden_dim_edge_encoder=cfg["hidden_dim_edge_encoder"],
        hidden_dim_node_decoder=cfg["hidden_dim_node_decoder"],
    ).to(device)

    epoch = load_checkpoint(
        ckpt_dir,
        models=model,
        device=device,
    )
    print(f"Loaded fine-tuned checkpoint from {ckpt_dir} (epoch {epoch})")

    model.eval()
    return model


def process_vtp_file(vtp_file_path):
    """Parse ``vtp_file_path`` with ``get_dataset`` and return ``(dgl_graph, coords, nu)``.

    Any error raised while loading is printed and then re-raised unchanged.
    """
    try:
        (
            _ref_u, _ref_v, _ref_p,
            _gnn_u, _gnn_v, _gnn_p,
            coords,
            _inflow_coords, _outflow_coords, _wall_coords, _polygon_coords,
            nu,
            dgl_graph,
        ) = get_dataset(vtp_file_path, return_graph=True)
    except Exception as e:
        print(f"Error processing VTP file: {e}")
        raise
    return dgl_graph, coords, nu


def run_inference(model, dgl_graph, device):
    """Evaluate ``model`` on ``dgl_graph`` and return ``(u, v, p)`` as NumPy column arrays.

    The graph is moved to ``device`` first; the forward pass runs without
    gradient tracking and columns 0, 1, 2 of the output become u, v, p.
    """
    graph = dgl_graph.to(device)
    with torch.no_grad():
        out = model(graph.ndata["x"], graph.edata["x"], graph)
        columns = [out[:, idx : idx + 1].detach().cpu().numpy() for idx in range(3)]
    return tuple(columns)


def save_results(input_vtp_path, output_path, pred_u, pred_v, pred_p):
    """Read ``input_vtp_path``, attach the fine-tuned fields and write the mesh to ``output_path``."""
    mesh = pv.read(input_vtp_path)
    fields = {"fine_tuned_u": pred_u, "fine_tuned_v": pred_v, "fine_tuned_p": pred_p}
    for field_name, values in fields.items():
        mesh[field_name] = values.flatten()
    mesh.save(output_path)
    print(f"Results saved to: {output_path}")


def compare_results(input_vtp_path, pred_u, pred_v, pred_p):
    """Print relative L2 errors (in percent) between fine-tuned, original GNN and reference fields.

    The reference fields ``u``/``v``/``p`` and the original GNN predictions
    ``pred_u``/``pred_v``/``pred_p`` are read from the point data of
    ``input_vtp_path``; all six are loaded before anything is printed.
    """
    mesh = pv.read(input_vtp_path)

    def column(name):
        return mesh.point_data[name].reshape(-1, 1)

    def rel_l2_pct(pred, ref):
        return np.linalg.norm(pred - ref) / np.linalg.norm(ref) * 100

    refs = [column(name) for name in ("u", "v", "p")]
    gnns = [column(name) for name in ("pred_u", "pred_v", "pred_p")]
    tuned = [pred_u, pred_v, pred_p]
    labels = ("U", "V", "P")

    sections = (
        ("\nOriginal GNN vs Reference:", "error", gnns, refs),
        ("\nFine-tuned Model vs Reference:", "error", tuned, refs),
        ("\nFine-tuned vs Original GNN:", "difference", tuned, gnns),
    )
    for header, metric, preds, targets in sections:
        print(header)
        for label, pred, target in zip(labels, preds, targets):
            print(f"  {label} {metric}: {rel_l2_pct(pred, target):.3f}%")


def main():
    """Run the checkpointed fine-tuned model on one VTP file, then save and compare results."""
    # Hard-coded run settings.
    input_vtp = "graph_5.vtp"
    ckpt_dir = "finetunecheckpoints"
    out_vtp = input_vtp.replace(".vtp", "_fine_tuned.vtp")
    run_comparison = True

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    print("Loading fine-tuned model...")
    model = load_fine_tuned_model(ckpt_dir, device)

    print("Processing VTP file...")
    try:
        graph, _coords, nu = process_vtp_file(input_vtp)
        n_nodes, n_edges = graph.number_of_nodes(), graph.number_of_edges()
        print(f"Graph has {n_nodes} nodes and {n_edges} edges")
        print(f"Viscosity (nu): {nu}")
    except Exception as exc:
        print(f"Failed to process VTP file: {exc}")
        return

    print("Running inference...")
    pred_u, pred_v, pred_p = run_inference(model, graph, device)
    fields = (pred_u, pred_v, pred_p)

    print("Saving results...")
    save_results(input_vtp, out_vtp, *fields)

    if run_comparison:
        print("Comparing results...")
        compare_results(input_vtp, *fields)

    print("\n✅ Inference completed successfully!")
    print(f"👉 Fine-tuned predictions saved in: {out_vtp}")


# Run the checkpoint-based inference when executed as a script.
if __name__ == "__main__":
    main()

Questions:

  1. What is the correct way to save and load a fine-tuned model in the PhysicsNemo framework to preserve all necessary state for inference?
  2. What is the role of the .mdlus file generated during checkpoint saving? Should it be used during inference, and how?
  3. Could someone take a look at the modifications I made to save the model and my inference script, and suggest improvements or corrections?

Thanks for your support!

1 Like

Hello @abouzar.ghasemi,

I’ve tried saving the fine-tuned model every 100 iterations, as train.py does, but the inference results still don’t match expectations. Below is how I’m saving the model during fine-tuning and the method I’m using for inference.

Code for saving the fine-tuned model:

import os
import time

import hydra
import numpy as np
import torch
import wandb
from hydra.utils import to_absolute_path
from omegaconf import DictConfig

# Optional dependency: any failure to import apex is ignored. Catching
# Exception (not a bare except) keeps Ctrl-C / SystemExit from being swallowed.
try:
    import apex
except Exception:
    pass

# pyvista is required to read and write the .vtp graph files.
try:
    import pyvista as pv
except ImportError as err:
    raise ImportError(
        "Stokes Dataset requires the pyvista library. Install with "
        + "pip install pyvista"
    ) from err

from collections import OrderedDict
from typing import Dict, Optional

from physicsnemo.launch.logging import (
    PythonLogger,
    RankZeroLoggingWrapper,
)
from physicsnemo.launch.logging.wandb import initialize_wandb
from physicsnemo.models.meshgraphnet import MeshGraphNet
from physicsnemo.sym.eq.pde import PDE
from physicsnemo.sym.eq.phy_informer import PhysicsInformer
from physicsnemo.sym.eq.spatial_grads.spatial_grads import compute_connectivity_tensor
from sympy import Function, Number, Symbol
from physicsnemo.launch.utils import save_checkpoint

from utils import get_dataset, relative_lp_error


class Stokes(PDE):
    """Steady incompressible Stokes equations written symbolically with sympy.

    Populates ``self.equations`` with the residuals ``continuity``,
    ``momentum_x``, ``momentum_y`` and, unless ``dim == 2``, ``momentum_z``.
    ``nu`` may be a string (a spatially varying field of that name), a float
    or int (wrapped in a sympy ``Number``), or any other object used as-is.
    """

    def __init__(self, nu, dim=3):
        self.dim = dim

        x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

        # Independent variables; the 2D variant has no ``z``.
        input_variables = {"x": x, "y": y, "z": z}
        if self.dim == 2:
            del input_variables["z"]

        def field(name):
            # Unknown scalar field depending on the active coordinates.
            return Function(name)(*input_variables)

        u = field("u")
        v = field("v")
        w = field("w") if self.dim == 3 else Number(0)  # no z-velocity unless 3D
        p = field("p")

        # Kinematic viscosity: a named field or a numeric constant.
        if isinstance(nu, str):
            nu = field(nu)
        elif isinstance(nu, (float, int)):
            nu = Number(nu)

        def laplacian(f):
            return f.diff(x).diff(x) + f.diff(y).diff(y) + f.diff(z).diff(z)

        self.equations = {
            "continuity": u.diff(x) + v.diff(y) + w.diff(z),
            "momentum_x": p.diff(x) - nu * laplacian(u),
            "momentum_y": p.diff(y) - nu * laplacian(v),
        }
        if self.dim != 2:
            self.equations["momentum_z"] = p.diff(z) - nu * laplacian(w)


class PhysicsInformedFineTuner:
    """
    Physics-informed fine-tuning of a MeshGraphNet on a single graph.

    The network is optimized on one ``dgl_graph`` using
      * a data loss against the pretrained GNN predictions (``gnn_u/v/p``) at
        interior nodes,
      * inflow (parabolic profile) and no-slip boundary-condition losses, and
      * 2D Stokes PDE residuals (continuity, x/y momentum) computed with
        PhysicsNeMo Sym's ``PhysicsInformer``.

    NOTE: ``self.model`` is constructed with fresh weights in ``__init__`` (no
    pretrained checkpoint is loaded here), and every loss term is evaluated on
    the single graph passed in. A checkpoint of this model is therefore fitted
    to that one geometry. Running it on a different ``.vtp`` file is not
    expected to reach the accuracy observed during fine-tuning.
    """

    def __init__(
        self,
        cfg,
        device,
        gnn_u,
        gnn_v,
        gnn_p,
        coords,
        coords_inflow,
        coords_noslip,
        nu,
        ref_u,
        ref_v,
        ref_p,
        dgl_graph,
    ):
        super().__init__()

        self.device = device
        self.nu = nu
        self.dgl_graph = dgl_graph.to(self.device)

        # (src, dst) edge pairs -> connectivity used by the least-squares
        # gradient backend (PhysicsInformer is built with compute_connectivity=False).
        src, dst = dgl_graph.edges()
        edge_tensor = torch.stack([src, dst], dim=1).to(self.device)
        self.connectivity_tensor = compute_connectivity_tensor(
            dgl_graph.nodes(), edge_tensor
        )
        self.connectivity_tensor = self.connectivity_tensor.to(self.device)

        # Reference solution (validation only) and GNN predictions (data loss),
        # stored as float32 tensors on the target device.
        self.ref_u = torch.tensor(ref_u).float().to(self.device)
        self.ref_v = torch.tensor(ref_v).float().to(self.device)
        self.ref_p = torch.tensor(ref_p).float().to(self.device)

        self.gnn_u = torch.tensor(gnn_u).float().to(self.device)
        self.gnn_v = torch.tensor(gnn_v).float().to(self.device)
        self.gnn_p = torch.tensor(gnn_p).float().to(self.device)

        # Of these, only coords_inflow is read by the methods in this class.
        self.coords = torch.tensor(coords, requires_grad=True).float().to(self.device)
        self.coords_inflow = (
            torch.tensor(coords_inflow, requires_grad=True).float().to(self.device)
        )
        self.coords_noslip = (
            torch.tensor(coords_noslip, requires_grad=True).float().to(self.device)
        )

        # Freshly initialised network (see class NOTE).
        self.model = MeshGraphNet(
            cfg.input_dim_nodes
            + 128,  # additional 128 node features from fourier features
            cfg.input_dim_edges,
            cfg.output_dim,
            aggregation=cfg.aggregation,
            hidden_dim_node_encoder=cfg.hidden_dim_node_encoder,
            hidden_dim_edge_encoder=cfg.hidden_dim_edge_encoder,
            hidden_dim_node_decoder=cfg.hidden_dim_node_decoder,
        ).to(self.device)

        self.node_pde = Stokes(nu=self.nu, dim=2)

        # note: this example uses the PhysicsInformer class from PhysicsNeMo Sym to
        # construct the computational graph. This allows you to leverage PhysicsNeMo Sym's
        # optimized derivative backend to compute the derivatives, along with other
        # benefits like symbolic definition of PDEs and leveraging the PDEs from PhysicsNeMo
        # Sym's PDE module.

        self.phy_informer = PhysicsInformer(
            required_outputs=["continuity", "momentum_x", "momentum_y"],
            equations=self.node_pde,
            grad_method="least_squares",
            device=self.device,
            compute_connectivity=False,
        )

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=cfg.pi_lr,
            fused=True if torch.cuda.is_available() else False,
        )

        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=0.99935
        )

    def parabolic_inflow(self, y, U_max=0.3):
        """Parabolic inflow profile ``u = 4 U_max y (0.4 - y) / 0.4**2`` and ``v = 0``.

        ``u`` is zero at ``y = 0`` and ``y = 0.4`` and peaks at ``U_max`` at
        ``y = 0.2``; ``v`` has the same shape as ``y``.
        """
        u = 4 * U_max * y * (0.4 - y) / (0.4**2)
        v = torch.zeros_like(y)
        return u, v

    def loss(self):
        """Evaluate every loss term on the full graph.

        Returns
        -------
        tuple of torch.Tensor
            ``(loss_u, loss_v, loss_p, loss_u_in, loss_v_in, loss_u_noslip,
            loss_v_noslip, loss_mom_u, loss_mom_v, loss_cont)``: scalar mean
            squared errors for the data, inflow, no-slip and PDE terms.
        """
        out = self.model(
            self.dgl_graph.ndata["x"], self.dgl_graph.edata["x"], self.dgl_graph
        )

        # ndata["marker"] rows are compared against 5-element one-hot patterns
        # to select node groups.
        # inflow points
        mask_inflow = (
            self.dgl_graph.ndata["marker"]
            == torch.tensor([0, 1, 0, 0, 0]).to(self.device)
        ).all(dim=1)
        results_inflow = {
            k: out[:, i : i + 1][mask_inflow] for i, k in enumerate(["u", "v", "p"])
        }
        pred_u_in, pred_v_in = results_inflow["u"], results_inflow["v"]

        # no-slip points (two marker groups, presumably wall and polygon)
        mask_1 = (
            self.dgl_graph.ndata["marker"]
            == torch.tensor([0, 0, 0, 1, 0]).to(self.device)
        ).all(dim=1)
        mask_2 = (
            self.dgl_graph.ndata["marker"]
            == torch.tensor([0, 0, 0, 0, 1]).to(self.device)
        ).all(dim=1)
        mask_noslip = torch.logical_or(mask_1, mask_2)
        results_noslip = {
            k: out[:, i : i + 1][mask_noslip] for i, k in enumerate(["u", "v", "p"])
        }
        pred_u_noslip, pred_v_noslip = results_noslip["u"], results_noslip["v"]

        # interior points
        mask_int = (
            self.dgl_graph.ndata["marker"]
            == torch.tensor([1, 0, 0, 0, 0]).to(self.device)
        ).all(dim=1)
        model_out = {
            k: out[:, i : i + 1][mask_int] for i, k in enumerate(["u", "v", "p"])
        }
        # PDE residuals are computed on all nodes and then restricted to the
        # interior (the least-squares gradients presumably need neighbours).
        results_int = self.phy_informer.forward(
            {
                "coordinates": self.dgl_graph.ndata["pos"][:, 0:2],
                "u": out[:, 0:1],
                "v": out[:, 1:2],
                "p": out[:, 2:3],
                "connectivity_tensor": self.connectivity_tensor,
            }
        )
        pred_mom_u, pred_mom_v, pred_cont = (
            results_int["momentum_x"][mask_int],
            results_int["momentum_y"][mask_int],
            results_int["continuity"][mask_int],
        )
        pred_u, pred_v, pred_p = model_out["u"], model_out["v"], model_out["p"]

        u_in, v_in = self.parabolic_inflow(self.coords_inflow[:, 1:2])

        # Compute losses
        # data loss (against the pretrained GNN predictions, not the reference)
        loss_u = torch.mean((self.gnn_u[mask_int] - pred_u) ** 2)
        loss_v = torch.mean((self.gnn_v[mask_int] - pred_v) ** 2)
        loss_p = torch.mean((self.gnn_p[mask_int] - pred_p) ** 2)

        # inflow boundary condition loss
        loss_u_in = torch.mean((u_in - pred_u_in) ** 2)
        loss_v_in = torch.mean((v_in - pred_v_in) ** 2)

        # noslip boundary condition loss
        loss_u_noslip = torch.mean(pred_u_noslip**2)
        loss_v_noslip = torch.mean(pred_v_noslip**2)

        # pde loss
        loss_mom_u = torch.mean(pred_mom_u**2)
        loss_mom_v = torch.mean(pred_mom_v**2)
        loss_cont = torch.mean(pred_cont**2)

        return (
            loss_u,
            loss_v,
            loss_p,
            loss_u_in,
            loss_v_in,
            loss_u_noslip,
            loss_v_noslip,
            loss_mom_u,
            loss_mom_v,
            loss_cont,
        )

    def train(self):
        """Run one optimizer + scheduler step of PINN-based fine-tuning.

        Returns the ten individual loss terms (same order as ``loss``)
        followed by the current learning rate.
        """
        (
            loss_u,
            loss_v,
            loss_p,
            loss_u_in,
            loss_v_in,
            loss_u_noslip,
            loss_v_noslip,
            loss_mom_u,
            loss_mom_v,
            loss_cont,
        ) = self.loss()

        # Add custom weights to the different losses. The weights are chosen after
        # investigating the relative magnitudes of individual losses and their
        # convergence behavior.
        loss = (
            1 * loss_u
            + 1 * loss_v
            + 1 * loss_p
            + 10 * loss_u_in
            + 10 * loss_v_in
            + 10 * loss_u_noslip
            + 10 * loss_v_noslip
            + 1 * loss_mom_u
            + 1 * loss_mom_v
            + 10 * loss_cont
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        return (
            loss_u,
            loss_v,
            loss_p,
            loss_u_in,
            loss_v_in,
            loss_u_noslip,
            loss_v_noslip,
            loss_mom_u,
            loss_mom_v,
            loss_cont,
            self.optimizer.param_groups[0]["lr"],
        )

    def save_checkpoint1(
        self, iteration: int, checkpoint_dir: str = "./workspace/checkpoints"
    ) -> str:
        """Save model, optimizer and scheduler state with ``torch.save``.

        Writes ``<checkpoint_dir>/finetuner_iter_<iteration>.pt`` containing a
        dict with keys ``iteration``, ``model_state_dict``,
        ``optimizer_state_dict`` and ``scheduler_state_dict``. To reload it for
        inference, rebuild ``MeshGraphNet`` with the same arguments as in
        ``__init__`` and call ``load_state_dict(ckpt["model_state_dict"])``.

        Returns
        -------
        str
            Path of the written checkpoint file.
        """
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"finetuner_iter_{iteration}.pt")
        save_dict = {
            "iteration": iteration,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
        }
        torch.save(save_dict, checkpoint_path)

        print(f"Checkpoint saved at iteration {iteration} to {checkpoint_path}")
        return checkpoint_path

    def validation(self):
        """Validation during the PINN fine-tuning step.

        Returns the relative L2 errors ``||ref - pred|| / ||ref||`` for u, v, p
        against the reference solution and logs them to wandb. The model is
        switched back to training mode afterwards, even if an error occurs.
        """
        self.model.eval()
        try:
            with torch.no_grad():
                out = self.model(
                    self.dgl_graph.ndata["x"], self.dgl_graph.edata["x"], self.dgl_graph
                )
                pred_u, pred_v, pred_p = (out[:, i : i + 1] for i in range(3))
                error_u = torch.linalg.norm(self.ref_u - pred_u) / torch.linalg.norm(
                    self.ref_u
                )
                error_v = torch.linalg.norm(self.ref_v - pred_v) / torch.linalg.norm(
                    self.ref_v
                )
                error_p = torch.linalg.norm(self.ref_p - pred_p) / torch.linalg.norm(
                    self.ref_p
                )
                # NOTE(review): the keys say "(%)" but the values are fractions
                # (not multiplied by 100).
                wandb.log(
                    {
                        "test_u_error (%)": error_u.detach().cpu().numpy(),
                        "test_v_error (%)": error_v.detach().cpu().numpy(),
                        "test_p_error (%)": error_p.detach().cpu().numpy(),
                    }
                )
        finally:
            # Restore training mode so subsequent train() steps are unaffected.
            self.model.train()
        return error_u, error_v, error_p


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Physics-informed fine-tuning on one graph, with checkpointing and final inference.

    Every 100 iterations (after iteration 0), and once more after the last
    iteration, the model is saved in two formats:

    * ``pi_fine_tuner.save_checkpoint1`` -> ``./workspace/checkpoints/finetuner_iter_<it>.pt``
      (plain ``torch.save`` dict; reload with
      ``model.load_state_dict(ckpt["model_state_dict"])``), and
    * PhysicsNemo ``save_checkpoint`` -> ``cfg.finetune_ckpt`` (``.mdlus`` model
      file plus optimizer/scheduler state; reload with ``load_checkpoint``).

    Each save is tagged with the iteration it was taken at. The final save is
    tagged ``cfg.pi_iters`` and holds exactly the weights used for the
    predictions written at the end. Those predictions are stored as
    ``filtered_u/v/p`` point fields in the input graph file, which is
    overwritten.
    """
    # CUDA support
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # initialize loggers
    initialize_wandb(
        project="PhysicsNeMo-Launch",
        entity="PhysicsNeMo",
        name="Stokes-Physics-Informed-Fine-Tuning",
        group="Stokes-DDP-Group",
        mode=cfg.wandb_mode,
    )

    logger = PythonLogger("main")  # General python logger
    logger.file_logging()

    # Get dataset
    path = os.path.join(to_absolute_path(cfg.results_dir), cfg.graph_path)

    # get_dataset() function here provides the true values (ref_*) and the gnn
    # predictions (gnn_*) along with other data required for the PINN training.
    (
        ref_u,
        ref_v,
        ref_p,
        gnn_u,
        gnn_v,
        gnn_p,
        coords,
        coords_inflow,
        coords_outflow,
        coords_wall,
        coords_polygon,
        nu,
        dgl_graph,
    ) = get_dataset(path, return_graph=True)
    coords_noslip = np.concatenate([coords_wall, coords_polygon], axis=0)

    dgl_graph = dgl_graph.to(device)

    # Initialize model
    pi_fine_tuner = PhysicsInformedFineTuner(
        cfg,
        device,
        gnn_u,
        gnn_v,
        gnn_p,
        coords,
        coords_inflow,
        coords_noslip,
        nu,
        ref_u,
        ref_v,
        ref_p,
        dgl_graph,
    )

    checkpoint_dir = "./workspace/checkpoints"
    # Resolved once up front so a missing config key fails before training.
    finetune_ckpt_dir = to_absolute_path(cfg.finetune_ckpt)

    def save_all(iteration):
        """Write both checkpoint formats, tagged with ``iteration``."""
        pi_fine_tuner.save_checkpoint1(
            iteration=iteration, checkpoint_dir=checkpoint_dir
        )
        save_checkpoint(
            finetune_ckpt_dir,
            models=pi_fine_tuner.model,
            optimizer=pi_fine_tuner.optimizer,
            scheduler=pi_fine_tuner.scheduler,
            # Tag with the real iteration (previously always cfg.pi_iters,
            # which mislabelled intermediate weights as final).
            epoch=iteration,
        )

    logger.info("Inference (with physics-informed training for fine-tuning) started...")
    for iters in range(cfg.pi_iters):
        # Start timing the iteration
        start_iter_time = time.time()

        (
            loss_u,
            loss_v,
            loss_p,
            loss_u_in,
            loss_v_in,
            loss_u_noslip,
            loss_v_noslip,
            loss_mom_u,
            loss_mom_v,
            loss_cont,
            current_lr,
        ) = pi_fine_tuner.train()

        if iters % 100 == 0:
            error_u, error_v, error_p = pi_fine_tuner.validation()

            # Print losses
            logger.info(f"Iteration: {iters}")
            logger.info(f"Loss u: {loss_u.detach().cpu().numpy():.3e}")
            logger.info(f"Loss v: {loss_v.detach().cpu().numpy():.3e}")
            logger.info(f"Loss p: {loss_p.detach().cpu().numpy():.3e}")
            logger.info(f"Loss u_in: {loss_u_in.detach().cpu().numpy():.3e}")
            logger.info(f"Loss v_in: {loss_v_in.detach().cpu().numpy():.3e}")
            logger.info(f"Loss u noslip: {loss_u_noslip.detach().cpu().numpy():.3e}")
            logger.info(f"Loss v noslip: {loss_v_noslip.detach().cpu().numpy():.3e}")
            logger.info(f"Loss momentum u: {loss_mom_u.detach().cpu().numpy():.3e}")
            logger.info(f"Loss momentum v: {loss_mom_v.detach().cpu().numpy():.3e}")
            logger.info(f"Loss continuity: {loss_cont.detach().cpu().numpy():.3e}")
            logger.info(f"Learning Rate: {current_lr}")

            # Print errors
            logger.info(f"Error u: {error_u:.3e}")
            logger.info(f"Error v: {error_v:.3e}")
            logger.info(f"Error p: {error_p:.3e}")

            # Print iteration time
            end_iter_time = time.time()
            logger.info(
                f"This iteration took {end_iter_time - start_iter_time:.2f} seconds"
            )
            logger.info("-" * 50)

            if iters > 0:
                save_all(iters)

    logger.info("Physics-informed fine-tuning training completed!")

    # Save the final weights. The last in-loop save happens at the largest
    # multiple of 100 below cfg.pi_iters, so later updates would otherwise be lost.
    save_all(cfg.pi_iters)

    # Save results
    # Final inference call after fine-tuning predictions using the PINN model
    pi_fine_tuner.model.eval()
    with torch.no_grad():
        out = pi_fine_tuner.model(dgl_graph.ndata["x"], dgl_graph.edata["x"], dgl_graph)
        results_int_inf = {k: out[:, i : i + 1] for i, k in enumerate(["u", "v", "p"])}
        pred_u_inf, pred_v_inf, pred_p_inf = (
            results_int_inf["u"],
            results_int_inf["v"],
            results_int_inf["p"],
        )

        pred_u_inf = pred_u_inf.detach().cpu().numpy()
        pred_v_inf = pred_v_inf.detach().cpu().numpy()
        pred_p_inf = pred_p_inf.detach().cpu().numpy()

        # Predictions are added to, and overwrite, the input graph file.
        polydata = pv.read(path)
        polydata["filtered_u"] = pred_u_inf
        polydata["filtered_v"] = pred_v_inf
        polydata["filtered_p"] = pred_p_inf
        print(path)
        polydata.save(path)

    logger.info("Inference completed!")


# Hydra-decorated entry point; configuration comes from conf/config.yaml.
if __name__ == "__main__":
    main()

Code for running inference with the fine-tuned model:

import os
import hydra
import torch
import numpy as np
import pyvista as pv
from omegaconf import DictConfig
from hydra.utils import to_absolute_path
from physicsnemo.models.meshgraphnet import MeshGraphNet
from physicsnemo.launch.logging import PythonLogger
from physicsnemo.launch.utils import load_checkpoint

from dgl import DGLGraph
from utils import get_dataset  # Assuming this handles graph construction

class PhysicsInformedInference:
    """Run a fine-tuned MeshGraphNet checkpoint on a VTP file and save the predictions.

    The checkpoint is expected to be a ``torch.save`` dict containing a
    ``"model_state_dict"`` entry, as written by the fine-tuning script's
    ``save_checkpoint1``.
    """

    def __init__(self, cfg: DictConfig, logger, checkpoint_path=None):
        """
        Parameters
        ----------
        cfg : DictConfig
            Provides the architecture keys used below and ``results_dir``.
        logger
            Logger used for progress messages (needs an ``info`` method).
        checkpoint_path : str, optional
            Checkpoint file to load. Defaults to
            ``<results_dir>/finetuner_iter_9900.pt`` for backward compatibility.
            Point this at the checkpoint actually written by fine-tuning (for
            example the last iteration's file).
        """
        self.logger = logger
        self.cfg = cfg
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Recreate model architecture exactly as in physics-informed training
        self.model = MeshGraphNet(
            cfg.input_dim_nodes + 128,  # Must match training (+128 for Fourier features)
            cfg.input_dim_edges,
            cfg.output_dim,
            aggregation=cfg.aggregation,
            hidden_dim_node_encoder=cfg.hidden_dim_node_encoder,
            hidden_dim_edge_encoder=cfg.hidden_dim_edge_encoder,
            hidden_dim_node_decoder=cfg.hidden_dim_node_decoder,
        ).to(self.device)

        if checkpoint_path is None:
            checkpoint_path = os.path.join(
                to_absolute_path(cfg.results_dir), "finetuner_iter_9900.pt"
            )
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.eval()
        self.checkpoint_path = checkpoint_path
        self.logger.info(f"Loaded physics-informed model from {checkpoint_path}")

    def process_input(self, input_path):
        """Load ``input_path`` via ``get_dataset``; return ``(graph_on_device, coords, nu)``."""
        # Reuse dataset loading function from training
        _, _, _, _, _, _, coords, coords_inflow, _, _, _, nu, dgl_graph = get_dataset(
            input_path, return_graph=True
        )

        # Move graph to device
        dgl_graph = dgl_graph.to(self.device)
        return dgl_graph, coords, nu

    def run_inference(self, input_path, output_path):
        """Predict (u, v, p) for ``input_path`` and write them to ``output_path``.

        The input file is left untouched. A copy of it with ``filtered_u``,
        ``filtered_v`` and ``filtered_p`` point fields is saved to
        ``output_path``.

        Returns
        -------
        tuple of numpy.ndarray
            ``(pred_u, pred_v, pred_p)``, each a single-column ``(N, 1)`` array.
        """
        dgl_graph, _coords, _nu = self.process_input(input_path)

        with torch.no_grad():
            out = self.model(
                dgl_graph.ndata["x"],
                dgl_graph.edata["x"],
                dgl_graph,
            )
        self.logger.info(f"Inference output shape: {tuple(out.shape)}")

        pred_u_inf, pred_v_inf, pred_p_inf = (
            out[:, i : i + 1].detach().cpu().numpy() for i in range(3)
        )

        polydata = pv.read(input_path)
        polydata["filtered_u"] = pred_u_inf
        polydata["filtered_v"] = pred_v_inf
        polydata["filtered_p"] = pred_p_inf
        # Previously this overwrote input_path and ignored output_path.
        polydata.save(output_path)
        self.logger.info(f"Saved predictions to {output_path}")

        self.logger.info("Inference completed!")
        return pred_u_inf, pred_v_inf, pred_p_inf

@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: run the fine-tuned model on ``graph_8.vtp``, with ``predictions.vtp`` as output path."""
    logger = PythonLogger("physics_informed_inference")
    logger.file_logging()
    logger.info("Starting physics-informed inference")

    engine = PhysicsInformedInference(cfg, logger)
    engine.run_inference(
        to_absolute_path("graph_8.vtp"),  # input graph
        to_absolute_path("predictions.vtp"),  # output path
    )
    logger.info("Inference completed successfully")

# Hydra-decorated entry point; configuration comes from conf/config.yaml.
if __name__ == "__main__":
    main()

Questions:

  1. Why doesn't the example save the model after fine-tuning?
  2. How do I run inference with the saved fine-tuned model?

Thank you for your support!

1 Like