Isaacgym preview 4 actor root state returns nans with isaacgymenvs-style task

I was having the same issue with a more complex model so I wrote a simple task loading a cylinder (URDF) and dropping it onto a plane. After the first simulation step, the position and velocity of the actor come back as NaNs (after the call to refresh the actor root state tensor). By increasing controlFrequencyInv I can see that there is no instability in the physics - the cylinder falls to the ground and comes to a rest. The error is thrown from the call to the mlp, as the NaNs propagate upstream, but the NaNs first appear after refresh_actor_root_state. Why might this be happening? I follow the isaacgymenvs examples exactly, so I believe my physics parameters are reasonable. I make sure my tensors aren’t being garbage collected etc but am not sure how to debug further. Any advice would be appreciated!

Code below:

import numpy as np
import os
import torch

from isaacgym import gymtorch
from isaacgym import gymapi
from isaacgym.torch_utils import *

from isaacgymenvs.utils.torch_jit_utils import *
from isaacgymenvs.tasks.base.vec_task import VecTask

class Hallway(VecTask):
    def __init__(self, cfg, rl_device, sim_device, graphics_device_id, headless, virtual_screen_capture, force_render):
        self.cfg = cfg

        self.heading_weight = self.cfg["env"]["headingWeight"]
        self.actions_cost_scale = self.cfg["env"]["actionsCost"]
        self.velocity_cost_scale = self.cfg["env"]["velocityCost"]
        self.ang_velocity_cost_scale = self.cfg["env"]["angularVelocityCost"]
        self.death_cost = self.cfg["env"]["deathCost"]
        self.boundary = self.cfg["env"]["boundary"]
        self.force_scale = self.cfg["env"]["forceScale"]
        self.torque_scale = self.cfg["env"]["torqueScale"]
        self.max_range = self.cfg["env"]["maxRange"]

        self.debug_viz = self.cfg["env"]["enableDebugVis"]
        self.plane_static_friction = self.cfg["env"]["plane"]["staticFriction"]
        self.plane_dynamic_friction = self.cfg["env"]["plane"]["dynamicFriction"]
        self.plane_restitution = self.cfg["env"]["plane"]["restitution"]

        self.max_episode_length = self.cfg["env"]["episodeLength"]

        self.cfg["env"]["numObservations"] = 14
        self.cfg["env"]["numActions"] = 3

        super().__init__(config=self.cfg, rl_device=rl_device, sim_device=sim_device, graphics_device_id=graphics_device_id, headless=headless, virtual_screen_capture=virtual_screen_capture, force_render=force_render)
        
        if self.viewer != None:
            cam_pos = gymapi.Vec3(50.0, 25.0, 2.4)
            cam_target = gymapi.Vec3(45.0, 25.0, 0.0)
            self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)

        self._root_tensor = self.gym.acquire_actor_root_state_tensor(self.sim) #[num_bodies,13]
        self.root_tensor = gymtorch.wrap_tensor(self._root_tensor)
        self.root_pos = self.root_tensor[:, 0:3] #positions
        self.root_ori = self.root_tensor[:, 3:7] #quaternions
        self.root_lvel = self.root_tensor[:, 7:10] #linear velocities
        self.root_avel = self.root_tensor[:, 10:13] #angular velocities
        self.initial_root_state = self.root_tensor.clone()
        self.initial_root_state[:, 7:13] = 0  #set velocities to zero

        #self._net_cf = self.gym.acquire_net_contact_force_tensor(self.sim) #[num_bodies,3]
        #self.net_cf = gymtorch.wrap_tensor(self._net_cf)

        self.targets = to_torch([0, self.max_range, 0], device=self.device).repeat((self.num_envs, 1))
        self.target_angs = to_torch([0], device=self.device).repeat((self.num_envs, 1))
        self.target_dirs = to_torch([0, 1, 0], device=self.device).repeat((self.num_envs, 1))
        self.dt = self.cfg["sim"]["dt"]
        self.goal_vel = self.cfg["env"]["velocityGoal"] * torch.ones(self.num_envs,1,device=self.device)
        self.potentials = to_torch([-10./self.dt], device=self.device).repeat(self.num_envs)
        self.prev_potentials = self.potentials.clone()
        
        self.gym.refresh_actor_root_state_tensor(self.sim)

    def create_sim(self):
        # implement sim set up and environment creation here
        #    - set up-axis
        #    - call super().create_sim with device args (see docstring)
        #    - create ground plane
        #    - set up environments
        self.up_axis_idx = 2 # index of up axis: Y=1, Z=2
        self.sim = super().create_sim(self.device_id, self.graphics_device_id, self.physics_engine, self.sim_params)

        self._create_ground_plane()
        self._create_envs(self.num_envs, self.cfg["env"]['envSpacing'], int(np.sqrt(self.num_envs)))

    def _create_ground_plane(self):
        plane_params = gymapi.PlaneParams()
        plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0)
        plane_params.static_friction = self.plane_static_friction
        plane_params.dynamic_friction = self.plane_dynamic_friction
        plane_params.restitution = self.plane_restitution
        self.gym.add_ground(self.sim, plane_params)

    def _create_envs(self, num_envs, spacing, num_per_row):
        lower = gymapi.Vec3(-spacing, -spacing, 0)
        upper = gymapi.Vec3(spacing, spacing, spacing)

        asset_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../assets')
        asset_file = "urdf/bumpybot.urdf"

        if "asset" in self.cfg["env"]:
            asset_file = self.cfg["env"]["asset"].get("assetFileName", asset_file)

        asset_path = os.path.join(asset_root, asset_file)
        asset_root = os.path.dirname(asset_path)
        asset_file = os.path.basename(asset_path)

        asset_options = gymapi.AssetOptions()
        asset_options.slices_per_cylinder = 8
        asset_options.collapse_fixed_joints = True
        asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options)

        start_pose = gymapi.Transform()
        start_pose.p = gymapi.Vec3(*get_axis_params(0.5, self.up_axis_idx))
        start_pose.r = gymapi.Quat.from_euler_zyx(0, 0, 0)

        self.handles = []
        self.envs = []

        for i in range(self.num_envs):
            # create env instance
            env_ptr = self.gym.create_env(self.sim, lower, upper, num_per_row)
            handle = self.gym.create_actor(env_ptr, asset, start_pose, "bumpybot", i)

            self.envs.append(env_ptr)
            self.handles.append(handle)

    def _compute_reward(self, actions):
        self.rew_buf[:], self.reset_buf = compute_reward(
            self.obs_buf,
            self.reset_buf,
            self.progress_buf,
            self.actions,
            self.heading_weight,
            self.potentials,
            self.prev_potentials,
            self.actions_cost_scale,
            self.velocity_cost_scale,
            self.ang_velocity_cost_scale,
            self.boundary,
            self.goal_vel,
            self.death_cost,
            self.max_episode_length
            )

    def _compute_observations(self):
        self.gym.refresh_actor_root_state_tensor(self.sim)

        self.obs_buf[:], self.potentials[:], self.prev_potentials[:] = compute_observations(
            self.root_tensor,
            self.targets,
            self.target_angs,
            self.goal_vel,
            self.potentials,
            self.prev_potentials,
            self.dt,
            self.actions,
            self.max_range
            )

    def reset_idx(self, env_ids):

        positions = torch_rand_float(-0.2, 0.2, (len(env_ids), 2), device=self.device)
        pose = torch.zeros_like(self.initial_root_state)
        pose[:,:2] = positions[:, :]

        random_root = self.initial_root_state[env_ids] + pose

        env_ids_int32 = env_ids.to(dtype=torch.int32)
        self.gym.set_actor_root_state_tensor_indexed(self.sim,
                                                     gymtorch.unwrap_tensor(random_root),
                                                     gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32))

        to_target = self.targets[env_ids] - random_root[env_ids, :3]
        to_target[:, self.up_axis_idx] = 0
        self.prev_potentials[env_ids] = -torch.norm(to_target, p=2, dim=-1) / self.dt
        self.potentials[env_ids] = self.prev_potentials[env_ids].clone()

        self.progress_buf[env_ids] = 0
        self.reset_buf[env_ids] = 0

    def pre_physics_step(self, actions):
        # implement pre-physics simulation code here
        #    - e.g. apply actions
        # apply  forces
        self.actions = actions.to(self.device).clone()

        forces = torch.zeros_like(self.actions)
        forces[:, :2] = self.actions[:, :2]

        torques = torch.zeros_like(self.actions)
        torques[:, 2] = self.actions[:, 2]

        self.gym.apply_rigid_body_force_tensors(self.sim, gymtorch.unwrap_tensor(forces), gymtorch.unwrap_tensor(torques), gymapi.ENV_SPACE)

    def post_physics_step(self):
        # implement post-physics simulation code here
        #    - e.g. compute reward, compute observations
        self.progress_buf += 1

        env_ids = self.reset_buf.nonzero(as_tuple=False).flatten()
        if len(env_ids) > 0:
            self.reset_idx(env_ids)

        self._compute_observations()
        self._compute_reward(self.actions)

        # debug viz
        if self.viewer and self.debug_viz:
            self.gym.clear_lines(self.viewer)

            points = []
            colors = []
            for i in range(self.num_envs):
                origin = self.gym.get_env_origin(self.envs[i])
                pose = self.root_states[:, 0:3][i].cpu().numpy()
                glob_pos = gymapi.Vec3(origin.x + pose[0], origin.y + pose[1], origin.z + pose[2])
                points.append([glob_pos.x, glob_pos.y, glob_pos.z, glob_pos.x + 4 * self.heading_vec[i, 0].cpu().numpy(),
                               glob_pos.y + 4 * self.heading_vec[i, 1].cpu().numpy(),
                               glob_pos.z + 4 * self.heading_vec[i, 2].cpu().numpy()])
                colors.append([0.97, 0.1, 0.06])
                points.append([glob_pos.x, glob_pos.y, glob_pos.z, glob_pos.x + 4 * self.up_vec[i, 0].cpu().numpy(), glob_pos.y + 4 * self.up_vec[i, 1].cpu().numpy(),
                               glob_pos.z + 4 * self.up_vec[i, 2].cpu().numpy()])
                colors.append([0.05, 0.99, 0.04])

            self.gym.add_lines(self.viewer, None, self.num_envs * 2, points, colors)

#####################################################################
###=========================jit functions=========================###
#####################################################################

@torch.jit.script
def compute_reward(
    obs_buf,
    reset_buf,
    progress_buf,
    actions,
    heading_weight,
    potentials,
    prev_potentials,
    actions_cost_scale,
    velocity_cost_scale,
    ang_velocity_cost_scale,
    boundary,
    goal_vel,
    death_cost,
    max_episode_length
    ):
    # type: (Tensor, Tensor, Tensor, Tensor, float, Tensor, Tensor, float, float, float, float, Tensor, float, float) -> Tuple[Tensor, Tensor]

    # reward from the direction headed
    heading_weight_tensor = torch.ones_like(obs_buf[:, 9]) * heading_weight
    heading_reward = torch.where(torch.abs(obs_buf[:, 9]) < 0.2, heading_weight_tensor, - heading_weight * obs_buf[:, 9])

    actions_cost = actions_cost_scale * torch.sum(actions ** 2, dim=-1)

    velocity_cost = velocity_cost_scale * torch.squeeze((torch.linalg.norm(obs_buf[:, 3:5], dim=-1).view(-1,1) - goal_vel) ** 2)

    ang_velocity_cost = ang_velocity_cost_scale * obs_buf[:,6]**2

    # reward for duration of being alive
    alive_reward = torch.ones_like(potentials) * 2.0
    progress_reward = potentials - prev_potentials

    total_reward = progress_reward + alive_reward + heading_reward \
        - actions_cost - velocity_cost - ang_velocity_cost

    # adjust reward for dead agents
    total_reward = torch.where(torch.abs(obs_buf[:, 0]) > boundary, torch.ones_like(total_reward) * death_cost, total_reward)

    # reset agents
    reset = torch.where(torch.abs(obs_buf[:, 0]) > boundary, torch.ones_like(reset_buf), reset_buf)
    reset = torch.where(progress_buf >= max_episode_length - 1, torch.ones_like(reset_buf), reset_buf) #last arg "reset"

    return total_reward, reset

@torch.jit.script
def compute_observations(
    root_states,
    targets,
    target_angs,
    goal_vel,
    potentials, 
    prev_potentials,
    dt,
    actions,
    max_range
    ):
    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, float, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]

    # x y th vx vy vth gx gy gth ang2target gv fx fy fth

    position = root_states[:, 0:3]
    rotation = root_states[:, 3:7]
    velocity = root_states[:, 7:10]
    ang_velocity = root_states[:, 10:13]

    _,_,yaw = get_euler_xyz(rotation)

    target_angle = torch.atan2(targets[:, 2] - position[:, 2],
                               targets[:, 0] - position[:, 0])
    angle_to_target = target_angle - yaw

    to_target = targets - position
    to_target[:, 2] = 0

    prev_potentials_new = potentials.clone()
    potentials = -torch.norm(to_target, p=2, dim=-1) / dt

    # Normalize Observations
    position_norm = position / max_range
    targets_norm = targets / max_range
    heading = normalize_angle(yaw).unsqueeze(-1)
    angle_to_target = normalize_angle(angle_to_target).unsqueeze(-1)

    # obs_buf shapes: 2, 1, 2, 1, 2, 1, 1, 1, num_acts(3)
    obs = torch.cat((position_norm[:,:2].view(-1, 2), heading, velocity[:,:2].view(-1,2),
            ang_velocity[:,2].view(-1,1), targets_norm[:,:2].view(-1,2), target_angs[:,0].view(-1,1),
            angle_to_target, goal_vel, actions), dim=-1) #14

    return obs, potentials, prev_potentials_new

with task config

# used to create the object
name: Hallway

physics_engine: ${..physics_engine}

# if given, will override the device setting in gym.
env: 
  numEnvs: ${resolve_default:4096,${...num_envs}}
  envSpacing: 2
  episodeLength: 500
  controlFrequencyInv: 100
  enableDebugVis: False

  clipActions: 1.0

  velocityGoal : 2.0
  maxRange: 10.0

  # reward parameters
  headingWeight: 0.5

  # cost parameters
  actionsCost: 0.01
  velocityCost: 0.01
  angularVelocityCost: 0.01
  deathCost: -1.0
  boundary: 2

  # control parameters
  forceScale: 1
  torqueScale: 1

  asset:
    assetFileName: "urdf/bumpybot.urdf"

  plane:
    staticFriction: 1.0
    dynamicFriction: 1.0
    restitution: 0.0

  # set to True if you use camera sensors in the environment
  enableCameraSensors: False

sim:
  dt: 0.0166 # 1/60 s
  substeps: 2
  up_axis: "z"
  use_gpu_pipeline: ${eq:${...pipeline},"gpu"}
  gravity: [0.0, 0.0, -9.81]
  physx:
    num_threads: ${....num_threads}
    solver_type: ${....solver_type}
    use_gpu: ${contains:"cuda",${....sim_device}} # set to False to run on CPU
    num_position_iterations: 4
    num_velocity_iterations: 0
    contact_offset: 0.02
    rest_offset: 0.0
    bounce_threshold_velocity: 0.2
    max_depenetration_velocity: 10.0
    default_buffer_size_multiplier: 5.0
    max_gpu_contact_pairs: 8388608 # 8*1024*1024
    num_subscenes: ${....num_subscenes}
    contact_collection: 0 # 0: CC_NEVER (don't collect contact info), 1: CC_LAST_SUBSTEP (collect only contacts on last substep), 2: CC_ALL_SUBSTEPS (default - all contacts)

and train config

params:
  seed: ${...seed}

  algo:
    name: a2c_continuous

  model:
    name: continuous_a2c_logstd

  network:
    name: actor_critic
    separate: False

    space:
      continuous:
        mu_activation: None
        sigma_activation: None
        mu_init:
          name: default
        sigma_init:
          name: const_initializer
          val: 0
        fixed_sigma: True

    mlp:
      units: [200, 100, 50]
      activation: elu
      d2rl: False

      initializer:
        name: default
      regularizer:
        name: None

  load_checkpoint: ${if:${...checkpoint},True,False} # flag which sets whether to load the checkpoint
  load_path: ${...checkpoint} # path to the checkpoint to load

  config:
    name: ${resolve_default:Hallway,${....experiment}}
    full_experiment_name: ${.name}
    env_name: rlgpu
    multi_gpu: ${....multi_gpu}
    mixed_precision: True
    normalize_input: True
    normalize_value: True
    value_bootstrap: True
    num_actors: ${....task.env.numEnvs}
    reward_shaper:
      scale_value: 0.01
    normalize_advantage: True
    gamma: 0.99
    tau: 0.95
    learning_rate: 5e-4
    lr_schedule: adaptive
    kl_threshold: 0.008
    score_to_win: 20000
    max_epochs: ${resolve_default:1000,${....max_iterations}}
    save_best_after: 200
    save_frequency: 100
    print_stats: True
    grad_norm: 1.0
    entropy_coef: 0.0
    truncate_grads: True
    ppo: True
    e_clip: 0.2
    horizon_length: 32
    minibatch_size: 32768
    mini_epochs: 5
    critic_coef: 4
    clip_value: True
    seq_len: 4
    bounds_loss_coef: 0.0001

and urdf

<?xml version="1.0"?>
<robot name="bumpybot">

  <link name="base">
    <inertial>
      <origin rpy="0 0 0" xyz="0 0 0"/>
       <mass value="4.53592"/>
       <inertia ixx="1" ixy="0" ixz="0" iyy="1" iyz="0" izz="1"/>
    </inertial>
    <visual>
      <geometry>
        <cylinder length="0.635" radius="0.2794"/>
      </geometry>
      <material name="blue">
        <color rgba="0 0 1 1"/>
      </material>
    </visual>
    <collision>
      <origin xyz="0.0 0.0 0.0" rpy="0 0 0" />
      <geometry>
        <cylinder length="0.635" radius="0.2794" />
      </geometry>
      <sdf resolution="256"/>
    </collision>
  </link>

</robot>
3 Likes

I have the same issue with Isaacgym preview 4. I have two URDF files, A and B (I, unfortunately, cannot share the URDF files). The problem only exists when I use the URDF file A, which makes me wonder if its something to do with the URDF file that is causing NaN to appear when refresh_actor_root_state is called.

Any updates or solutions on how to fix this ??

I encountered the same issue. Has there been any solution to it so far?

The issue in the original code is because in __init__ , the line self._root_tensor = self.gym.acquire_actor_root_state_tensor(self.sim) is called before calling self.gym.refresh_actor_root_state_tensor(self.sim) , which initializes self._root_tensor to an all zeros tensor. This makes the quaternion in self.root_ori also all zeros, which is an invalid quaternion. When this is used to set the states in reset_idx, the invalid orientation will cause NaNs in physics. To fix the issue, we can add a call to self.gym.refresh_actor_root_state_tensor(self.sim) before acquiring the root states in __init__.

1 Like

I’m facing the exact same issue. Did you find out why this is happening?