Loading Two Learned Policies Applies Only One Observation

Dear All,

I’m trying to train a new policy that outputs more abstract actions on top of two learned low-level policies. I adapted the navigation demo code for two policies, but the observation of the first loaded policy is also being applied to the second policy. Judging from the shape mismatch in the traceback below (1024x147 and 66x512), the rolling policy, which expects a 66-dimensional input, is receiving the 147-dimensional walking observation instead. Is there a way to avoid this error?
I suspect that rolling_low_level_obs = self._rolling_low_level_obs_manager.compute_group("rolling_policy") is not returning the observations I expect.
Thank you!

Error output:

Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/robot/IsaacLab_v1.2.0/IsaacLab/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/hydra.py", line 99, in hydra_main
    func(env_cfg, agent_cfg, *args, **kwargs)
  File "/home/robot/IsaacLab_v1.2.0/IsaacLab/source/standalone/workflows/rsl_rl/train.py", line 144, in main
    runner.learn(num_learning_iterations=agent_cfg.max_iterations, init_at_random_ep_len=True)
  File "/home/robot/anaconda3/envs/iccar2025/lib/python3.10/site-packages/rsl_rl/runners/on_policy_runner.py", line 112, in learn
    obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
  File "/home/robot/IsaacLab_v1.2.0/IsaacLab/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/rsl_rl/vecenv_wrapper.py", line 177, in step
    obs_dict, rew, terminated, truncated, extras = self.env.step(actions)
  File "/home/robot/anaconda3/envs/iccar2025/lib/python3.10/site-packages/gymnasium/wrappers/record_video.py", line 166, in step
    ) = self.env.step(action)
  File "/home/robot/anaconda3/envs/iccar2025/lib/python3.10/site-packages/gymnasium/wrappers/order_enforcing.py", line 56, in step
    return self.env.step(action)
  File "/home/robot/IsaacLab_v1.2.0/IsaacLab/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_rl_env.py", line 169, in step
    self.action_manager.apply_action()
  File "/home/robot/IsaacLab_v1.2.0/IsaacLab/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/action_manager.py", line 328, in apply_action
    term.apply_actions()
  File "/home/robot/IsaacLab_v1.2.0/IsaacLab/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion_selector/mdp/pre_trained_policy_action.py", line 119, in apply_actions
    self.rolling_low_level_actions[:] = self.policy_2(rolling_low_level_obs)
  File "/home/robot/.local/share/ov/pkg/isaac-sim-4.2.0/exts/omni.isaac.ml_archive/pip_prebundle/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robot/.local/share/ov/pkg/isaac-sim-4.2.0/exts/omni.isaac.ml_archive/pip_prebundle/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/omni/isaac/lab_tasks/utils/wrappers/rsl_rl/exporter.py", line 13, in forward
    actor = self.actor
    normalizer = self.normalizer
    _0 = (actor).forward((normalizer).forward(x, ), )
          ~~~~~~~~~~~~~~ <--- HERE
    return _0
  def reset(self: __torch__.omni.isaac.lab_tasks.utils.wrappers.rsl_rl.exporter._TorchPolicyExporter) -> NoneType:
  File "code/__torch__/torch/nn/modules/container.py", line 22, in forward
    _5 = getattr(self, "5")
    _6 = getattr(self, "6")
    input0 = (_0).forward(input, )
              ~~~~~~~~~~~ <--- HERE
    input1 = (_1).forward(input0, )
    input2 = (_2).forward(input1, )
  File "code/__torch__/torch/nn/modules/linear.py", line 14, in forward
    weight = self.weight
    bias = self.bias
    return torch.linear(input, weight, bias)
           ~~~~~~~~~~~~ <--- HERE
class Identity(Module):
  __parameters__ = []

Traceback of TorchScript, original code (most recent call last):
  File "/home/robot/IsaacLab_v1.2.0/IsaacLab/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/rsl_rl/exporter.py", line 76, in forward
    def forward(self, x):
        return self.actor(self.normalizer(x))
               ~~~~~~~~~~ <--- HERE
  File "/home/robot/.local/share/ov/pkg/isaac-sim-4.2.0/exts/omni.isaac.ml_archive/pip_prebundle/torch/nn/modules/container.py", line 219, in forward
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
  File "/home/robot/.local/share/ov/pkg/isaac-sim-4.2.0/exts/omni.isaac.ml_archive/pip_prebundle/torch/nn/modules/linear.py", line 117, in forward
    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)
               ~~~~~~~~ <--- HERE
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x147 and 66x512)
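
For reference, this is the standard shape rule for torch.nn.functional.linear: mat2 is the transposed weight of the rolling policy's first Linear layer (in_features=66, out_features=512), while mat1 is the observation batch (1024 envs x 147 features). A minimal repro of the same mismatch, using only the dimensions taken from the traceback:

import torch

layer = torch.nn.Linear(66, 512)  # first layer of the rolling policy: in_features=66, out_features=512
obs = torch.zeros(1024, 147)      # the 147-dim observation batch that actually reaches the policy
layer(obs)                        # RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x147 and 66x512)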

My modified code (pre_trained_policy_action.py):

# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
from dataclasses import MISSING
from typing import TYPE_CHECKING

import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.assets import Articulation
from omni.isaac.lab.managers import ActionTerm, ActionTermCfg, ObservationGroupCfg, ObservationManager, CommandManager
from omni.isaac.lab.markers import VisualizationMarkers
from omni.isaac.lab.markers.config import BLUE_ARROW_X_MARKER_CFG, GREEN_ARROW_X_MARKER_CFG
from omni.isaac.lab.utils import configclass
from omni.isaac.lab.utils.assets import check_file_path, read_file

if TYPE_CHECKING:
    from omni.isaac.lab.envs import ManagerBasedRLEnv


class PreTrainedPolicyAction(ActionTerm):
    r"""Pre-trained policy action term.

    This action term infers a pre-trained policy and applies the corresponding low-level actions to the robot.
    The raw actions correspond to the commands for the pre-trained policy.

    """

    cfg: PreTrainedPolicyActionCfg
    """The configuration of the action term."""

    def __init__(self, cfg: PreTrainedPolicyActionCfg, env: ManagerBasedRLEnv) -> None:
        # initialize the action term
        super().__init__(cfg, env)

        self.robot: Articulation = env.scene[cfg.asset_name]
        #self.env = env


        # load policy
        #if not check_file_path(cfg.policy_path):
        #    raise FileNotFoundError(f"Policy file '{cfg.policy_path}' does not exist.")
        #file_bytes = read_file(cfg.policy_path)
        #self.policy = torch.jit.load(file_bytes).to(env.device).eval()
        for i, policy_path in enumerate([cfg.policy_path_1, cfg.policy_path_2], start=1):
            if not check_file_path(policy_path):
                raise FileNotFoundError(f"Policy file '{policy_path}' does not exist.")
            file_bytes = read_file(policy_path)
            setattr(self, f"policy_{i}", torch.jit.load(file_bytes).to(env.device).eval())

        self._raw_actions = torch.zeros(self.num_envs, self.action_dim, device=self.device)

        #self.random_indices = torch.randint(0, 2, (self.num_envs,))

        # prepare low level actions
        # walking policy
        self._walking_low_level_action_term: ActionTerm = cfg.walking_low_level_actions.class_type(cfg.walking_low_level_actions, env)
        self.walking_low_level_actions = torch.zeros(self.num_envs, self._walking_low_level_action_term.action_dim, device=self.device)
        # rolling policy
        self._rolling_low_level_action_term: ActionTerm = cfg.rolling_low_level_actions.class_type(cfg.rolling_low_level_actions, env)
        self.rolling_low_level_actions = torch.zeros(self.num_envs, self._rolling_low_level_action_term.action_dim, device=self.device)

        # remap some of the low level observations to internal observations
        # walking policy
        cfg.walking_low_level_observations.actions.func = lambda dummy_env: self.walking_low_level_actions
        cfg.walking_low_level_observations.actions.params = dict()
        cfg.walking_low_level_observations.velocity_commands.func = lambda dummy_env: self._raw_actions
        cfg.walking_low_level_observations.velocity_commands.params = dict()
        # rolling policy
        cfg.rolling_low_level_observations.actions.func = lambda dummy_env: self.rolling_low_level_actions
        cfg.rolling_low_level_observations.actions.params = dict()
        cfg.rolling_low_level_observations.velocity_commands.func = lambda dummy_env: self._raw_actions
        cfg.rolling_low_level_observations.velocity_commands.params = dict()

        # add the low level observations to the observation manager
        # walking policy
        self._walking_low_level_obs_manager = ObservationManager({"walking_policy": cfg.walking_low_level_observations}, env)
        # rolling policy
        self._rolling_low_level_obs_manager = ObservationManager({"rolling_policy": cfg.rolling_low_level_observations}, env)

        self._counter = 0

    """
    Properties.
    """

    @property
    def action_dim(self) -> int:
        return 3

    @property
    def raw_actions(self) -> torch.Tensor:
        return self._raw_actions

    @property
    def processed_actions(self) -> torch.Tensor:
        return self.raw_actions

    """
    Operations.
    """

    def process_actions(self, actions: torch.Tensor):
        self._raw_actions[:] = actions

    def apply_actions(self):
        if self._counter % self.cfg.low_level_decimation == 0:
            walking_low_level_obs = self._walking_low_level_obs_manager.compute_group("walking_policy")
            rolling_low_level_obs = self._rolling_low_level_obs_manager.compute_group("rolling_policy")

            self.walking_low_level_actions[:] = self.policy_1(walking_low_level_obs)
            self.rolling_low_level_actions[:] = self.policy_2(rolling_low_level_obs)
            #print(f"get_command: {self.env.command_manager.get_command('pose_command')}")
            #print(f"get_command.shape: {self.env.command_manager.get_command('pose_command').shape}")

            self._walking_low_level_action_term.process_actions(self.walking_low_level_actions)
            self._rolling_low_level_action_term.process_actions(self.rolling_low_level_actions)
            #self._walking_low_level_action_term.process_actions(result)

            self._counter = 0
        self._walking_low_level_action_term.apply_actions()
        self._counter += 1

    """
    Debug visualization.
    """

    def _set_debug_vis_impl(self, debug_vis: bool):
        # set visibility of markers
        # note: parent only deals with callbacks. not their visibility
        if debug_vis:
            # create markers if necessary for the first time
            if not hasattr(self, "base_vel_goal_visualizer"):
                # -- goal
                marker_cfg = GREEN_ARROW_X_MARKER_CFG.copy()
                marker_cfg.prim_path = "/Visuals/Actions/velocity_goal"
                marker_cfg.markers["arrow"].scale = (0.5, 0.5, 0.5)
                self.base_vel_goal_visualizer = VisualizationMarkers(marker_cfg)
                # -- current
                marker_cfg = BLUE_ARROW_X_MARKER_CFG.copy()
                marker_cfg.prim_path = "/Visuals/Actions/velocity_current"
                marker_cfg.markers["arrow"].scale = (0.5, 0.5, 0.5)
                self.base_vel_visualizer = VisualizationMarkers(marker_cfg)
            # set their visibility to true
            self.base_vel_goal_visualizer.set_visibility(True)
            self.base_vel_visualizer.set_visibility(True)
        else:
            if hasattr(self, "base_vel_goal_visualizer"):
                self.base_vel_goal_visualizer.set_visibility(False)
                self.base_vel_visualizer.set_visibility(False)

    def _debug_vis_callback(self, event):
        # check if robot is initialized
        # note: this is needed in-case the robot is de-initialized. we can't access the data
        if not self.robot.is_initialized:
            return
        # get marker location
        # -- base state
        base_pos_w = self.robot.data.root_pos_w.clone()
        base_pos_w[:, 2] += 0.5
        # -- resolve the scales and quaternions
        vel_des_arrow_scale, vel_des_arrow_quat = self._resolve_xy_velocity_to_arrow(self.raw_actions[:, :2])
        vel_arrow_scale, vel_arrow_quat = self._resolve_xy_velocity_to_arrow(self.robot.data.root_lin_vel_b[:, :2])
        # display markers
        self.base_vel_goal_visualizer.visualize(base_pos_w, vel_des_arrow_quat, vel_des_arrow_scale)
        self.base_vel_visualizer.visualize(base_pos_w, vel_arrow_quat, vel_arrow_scale)

    """
    Internal helpers.
    """

    def _resolve_xy_velocity_to_arrow(self, xy_velocity: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Converts the XY base velocity command to arrow direction rotation."""
        # obtain default scale of the marker
        default_scale = self.base_vel_goal_visualizer.cfg.markers["arrow"].scale
        # arrow-scale
        arrow_scale = torch.tensor(default_scale, device=self.device).repeat(xy_velocity.shape[0], 1)
        arrow_scale[:, 0] *= torch.linalg.norm(xy_velocity, dim=1) * 3.0
        # arrow-direction
        heading_angle = torch.atan2(xy_velocity[:, 1], xy_velocity[:, 0])
        zeros = torch.zeros_like(heading_angle)
        arrow_quat = math_utils.quat_from_euler_xyz(zeros, zeros, heading_angle)
        # convert everything back from base to world frame
        base_quat_w = self.robot.data.root_quat_w
        arrow_quat = math_utils.quat_mul(base_quat_w, arrow_quat)

        return arrow_scale, arrow_quat


@configclass
class PreTrainedPolicyActionCfg(ActionTermCfg):
    """Configuration for pre-trained policy action term.

    See :class:`PreTrainedPolicyAction` for more details.
    """

    class_type: type[ActionTerm] = PreTrainedPolicyAction
    """ Class of the action term."""
    asset_name: str = MISSING
    """Name of the asset in the environment for which the commands are generated."""
    #policy_path: str = MISSING
    policy_path_1: str = MISSING
    policy_path_2: str = MISSING
    """Path to the low level policy (.pt files)."""
    low_level_decimation: int = 4
    """Decimation factor for the low level action term."""
    #low_level_actions: ActionTermCfg = MISSING
    walking_low_level_actions: ActionTermCfg = MISSING
    rolling_low_level_actions: ActionTermCfg = MISSING
    """Low level action configuration."""
    #low_level_observations: ObservationGroupCfg = MISSING
    walking_low_level_observations: ObservationGroupCfg = MISSING
    rolling_low_level_observations: ObservationGroupCfg = MISSING
    """Low level observation configuration."""
    debug_vis: bool = True
    """Whether to visualize debug information. Defaults to False."""

The paths and observation/action configurations of the learned policies are passed in from the environment configuration shown below.
selector_env_cfg.py:

# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import math

from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
from omni.isaac.lab.managers import EventTermCfg as EventTerm
from omni.isaac.lab.managers import ObservationGroupCfg as ObsGroup
from omni.isaac.lab.managers import ObservationTermCfg as ObsTerm
from omni.isaac.lab.managers import RewardTermCfg as RewTerm
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.managers import TerminationTermCfg as DoneTerm
from omni.isaac.lab.utils import configclass
from omni.isaac.lab.utils.assets import ISAACLAB_NUCLEUS_DIR
from omni.isaac.lab.managers import ManagerTermBase

#import omni.isaac.lab_tasks.manager_based.navigation.mdp as mdp
import omni.isaac.lab_tasks.manager_based.locomotion_selector.mdp as mdp
#from omni.isaac.lab_tasks.manager_based.locomotion.velocity.config.anymal_c.flat_env_cfg import AnymalCFlatEnvCfg
from omni.isaac.lab_tasks.manager_based.locomotion.velocity.config.AT3R.rough_walking_env_cfg import AT3RRoughEnvCfg as WalkingEnvCfg
from omni.isaac.lab_tasks.manager_based.locomotion.velocity.config.AT3R.rough_rolling_env_cfg import AT3RRoughEnvCfg as RollingEnvCfg

#LOW_LEVEL_ENV_CFG = AnymalCFlatEnvCfg()
WALKING_ENV_CFG = WalkingEnvCfg()
ROLLING_ENV_CFG = RollingEnvCfg()


@configclass
class EventCfg:
    """Configuration for events."""

    reset_base = EventTerm(
        func=mdp.reset_root_state_uniform,
        mode="reset",
        params={
            "pose_range": {"x": (-0.0, 0.0), "y": (-0.0, 0.0), "yaw": (-3.14, 3.14)},
            "velocity_range": {
                "x": (-0.0, 0.0),
                "y": (-0.0, 0.0),
                "z": (-0.0, 0.0),
                "roll": (-0.0, 0.0),
                "pitch": (-0.0, 0.0),
                "yaw": (-0.0, 0.0),
            },
        },
    )


@configclass
class ActionsCfg:
    """Action terms for the MDP."""

    #pre_trained_policy_action: mdp.PreTrainedPolicyActionCfg = mdp.PreTrainedPolicyActionCfg(
    #    asset_name="robot",
    #    policy_path=f"{ISAACLAB_NUCLEUS_DIR}/Policies/ANYmal-C/Blind/policy.pt",
    #    low_level_decimation=4,
    #    low_level_actions=LOW_LEVEL_ENV_CFG.actions.joint_pos,
    #    low_level_observations=LOW_LEVEL_ENV_CFG.observations.policy,
    #)
    pre_trained_policy_action: mdp.PreTrainedPolicyActionCfg = mdp.PreTrainedPolicyActionCfg(
        asset_name="robot",
        policy_path_1=f"/home/robot/IsaacLab_v1.2.0/IsaacLab/logs/policies/walking_policy/2024-12-11_15-33-54/exported/policy.pt",
        policy_path_2=f"/home/robot/IsaacLab_v1.2.0/IsaacLab/logs/policies/rolling_policy/2024-12-13_13-41-19/exported/policy.pt",
        low_level_decimation=4,
        walking_low_level_actions=WALKING_ENV_CFG.actions.joint_pos,
        walking_low_level_observations=WALKING_ENV_CFG.observations.policy,
        rolling_low_level_actions=ROLLING_ENV_CFG.actions.joint_pos,
        rolling_low_level_observations=WALKING_ENV_CFG.observations.policy,
    )

@configclass
class ObservationsCfg:
    """Observation specifications for the MDP."""

    @configclass
    class PolicyCfg(ObsGroup):
        """Observations for policy group."""

        # observation terms (order preserved)
        base_lin_vel = ObsTerm(func=mdp.base_lin_vel)
        projected_gravity = ObsTerm(func=mdp.projected_gravity)
        pose_command = ObsTerm(func=mdp.generated_commands, params={"command_name": "pose_command"})

    # observation groups
    policy: PolicyCfg = PolicyCfg()


@configclass
class RewardsCfg:
    """Reward terms for the MDP."""

    termination_penalty = RewTerm(func=mdp.is_terminated, weight=-400.0)
    position_tracking = RewTerm(
        func=mdp.position_command_error_tanh,
        weight=0.5,
        params={"std": 2.0, "command_name": "pose_command"},
    )
    position_tracking_fine_grained = RewTerm(
        func=mdp.position_command_error_tanh,
        weight=0.5,
        params={"std": 0.2, "command_name": "pose_command"},
    )
    orientation_tracking = RewTerm(
        func=mdp.heading_command_error_abs,
        weight=-0.2,
        params={"command_name": "pose_command"},
    )
    


@configclass
class CommandsCfg:
    """Command terms for the MDP."""

    pose_command = mdp.UniformPose2dCommandCfg(
        asset_name="robot",
        simple_heading=False,
        resampling_time_range=(8.0, 8.0),
        debug_vis=False,
        ranges=mdp.UniformPose2dCommandCfg.Ranges(pos_x=(-3.0, 3.0), pos_y=(-3.0, 3.0), heading=(-math.pi, math.pi)),
    )
    #pose_command = mdp.NullCommandCfg()
    #print(f"num_envs: {ManagerTermBase.num_envs}")


@configclass
class TerminationsCfg:
    """Termination terms for the MDP."""

    time_out = DoneTerm(func=mdp.time_out, time_out=True)
    base_contact = DoneTerm(
        func=mdp.illegal_contact,
        params={"sensor_cfg": SceneEntityCfg("contact_forces", body_names="MainBody"), "threshold": 1.0},
    )


@configclass
#class NavigationEnvCfg(ManagerBasedRLEnvCfg):
class SelectorEnvCfg(ManagerBasedRLEnvCfg):
    """Configuration for the navigation environment."""

    # environment settings
    #scene: SceneEntityCfg = LOW_LEVEL_ENV_CFG.scene
    scene: SceneEntityCfg = WALKING_ENV_CFG.scene
    actions: ActionsCfg = ActionsCfg()
    observations: ObservationsCfg = ObservationsCfg()
    events: EventCfg = EventCfg()
    # mdp settings
    commands: CommandsCfg = CommandsCfg()
    rewards: RewardsCfg = RewardsCfg()
    terminations: TerminationsCfg = TerminationsCfg()

    def __post_init__(self):
        """Post initialization."""

        #self.sim.dt = LOW_LEVEL_ENV_CFG.sim.dt
        #self.sim.render_interval = LOW_LEVEL_ENV_CFG.decimation
        #self.decimation = LOW_LEVEL_ENV_CFG.decimation * 10
        self.sim.dt = WALKING_ENV_CFG.sim.dt
        self.sim.render_interval = WALKING_ENV_CFG.decimation
        self.decimation = WALKING_ENV_CFG.decimation * 10
        self.episode_length_s = self.commands.pose_command.resampling_time_range[1]

        if self.scene.height_scanner is not None:
            self.scene.height_scanner.update_period = (
                self.actions.pre_trained_policy_action.low_level_decimation * self.sim.dt
            )
        if self.scene.contact_forces is not None:
            self.scene.contact_forces.update_period = self.sim.dt


#class NavigationEnvCfg_PLAY(NavigationEnvCfg):
class SelectorEnvCfg_PLAY(SelectorEnvCfg):
    def __post_init__(self) -> None:
        # post init of parent
        super().__post_init__()

        # make a smaller scene for play
        self.scene.num_envs = 16
        self.scene.env_spacing = 2.5
        # disable randomization for play
        self.observations.policy.enable_corruption = False

Isaac Sim Version

4.2.0

Isaac Lab Version (if applicable)

1.2

Operating System

Ubuntu

GPU Information

  • Model: RTX A4000
  • Driver Version:

Thank you for posting this. Please try Isaac Lab 1.3, and if you still face this issue, open a question in the issues of the Isaac Lab GitHub repository, if you haven't already. We are concentrating our Isaac Lab user-support efforts in that repository. Thank you.

Hello @phennings,

Thank you for your reply!
This turned out to be caused by a mistake in my own code. The two trained policies are now working properly.
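
In case it helps anyone hitting the same error: the likely culprit in the ActionsCfg I posted above is that the rolling policy was given the walking observation group. A corrected version of that block (assuming the rolling environment's own observation group matches the 66-dim input of the exported rolling policy) would be:

    pre_trained_policy_action: mdp.PreTrainedPolicyActionCfg = mdp.PreTrainedPolicyActionCfg(
        asset_name="robot",
        policy_path_1=f"/home/robot/IsaacLab_v1.2.0/IsaacLab/logs/policies/walking_policy/2024-12-11_15-33-54/exported/policy.pt",
        policy_path_2=f"/home/robot/IsaacLab_v1.2.0/IsaacLab/logs/policies/rolling_policy/2024-12-13_13-41-19/exported/policy.pt",
        low_level_decimation=4,
        walking_low_level_actions=WALKING_ENV_CFG.actions.joint_pos,
        walking_low_level_observations=WALKING_ENV_CFG.observations.policy,
        rolling_low_level_actions=ROLLING_ENV_CFG.actions.joint_pos,
        # use the rolling environment's observation group instead of the walking one
        rolling_low_level_observations=ROLLING_ENV_CFG.observations.policy,
    )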
