FourCastNet implementation different from paper

Hi,

I am interested in using FourCastNet, and I’m going through the source code on GitHub (NVlabs/FourCastNet), assuming it’s the same implementation as in Modulus.

In the paper, the AFNO layer is defined such that

bias = x
x = RFFT2(x)
x = x.reshape(b, h, w//2+1, k, d//k)
x = BlockMLP(x)
x = x.reshape(b, h, w//2+1, d)
x = SoftShrink(x)
x = IRFFT2(x)
return x + bias

such that the MLP happens between the FFT and IFFT.
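The paper’s pseudocode above can be sketched in PyTorch roughly as follows. This is a hedged sketch, not the paper’s or Modulus’s actual code: the weight shapes, the softshrink threshold, and the choice to apply the block MLP separately to the real and imaginary parts (the real AFNO2D uses a complex-valued multiply) are all simplifying assumptions.

```python
import torch
import torch.nn.functional as F

def afno_layer_paper(x, w1, b1, w2, b2, sparsity=0.01):
    """Hypothetical sketch of the paper's AFNO mixing step.

    x:      (b, h, w, d) real input
    w1, w2: (k, d//k, d//k) block-diagonal MLP weights (assumed shapes)
    b1, b2: (k, d//k) biases
    """
    b, h, w, d = x.shape
    k, block = w1.shape[0], w1.shape[1]  # k blocks of size d // k
    bias = x
    # 2D real FFT over the spatial dimensions
    xf = torch.fft.rfft2(x, dim=(1, 2))          # (b, h, w//2+1, d), complex
    xf = xf.reshape(b, h, w // 2 + 1, k, block)  # split channels into k blocks

    # Two-layer block-diagonal MLP; each block has its own weights.
    def block_mlp(t):
        t = torch.einsum("...ki,kio->...ko", t, w1) + b1
        t = F.relu(t)
        return torch.einsum("...ki,kio->...ko", t, w2) + b2

    # Simplification: apply the MLP to real and imaginary parts independently
    xf = torch.complex(block_mlp(xf.real), block_mlp(xf.imag))
    xf = xf.reshape(b, h, w // 2 + 1, d)
    # Soft-shrinkage promotes sparsity in the spectral coefficients
    xf = torch.complex(F.softshrink(xf.real, sparsity),
                       F.softshrink(xf.imag, sparsity))
    x = torch.fft.irfft2(xf, s=(h, w), dim=(1, 2))
    return x + bias
```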
However, in the codebase, it’s

  residual = x
  x = self.norm1(x)
  x = self.filter(x) 

  if self.double_skip:
      x = x + residual
      residual = x

  x = self.norm2(x)
  x = self.mlp(x)
  x = self.drop_path(x)
  x = x + residual 

Here the MLP happens after the IFFT. This seems like a pretty important difference. Am I missing something?

Hi @alexandre.szenicer

Thanks for your interest. The section of code you’re looking at is a block of AFNO inside FourCastNet, which is what is implemented in Modulus as well as in the FCN repo you linked.

To see the Fourier convolution, we need to look at the AFNO2D module. Inside the forward of this torch module, you will see the FFT transforms. Between the rfft2 and irfft2 there are some reshapes as well as einsum ops that look like the following:

o2_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = (
    torch.einsum(
        "...bi,bio->...bo",
        o1_imag[
            :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
        ],
        self.w2[0],
    )
    + torch.einsum(
        "...bi,bio->...bo",
        o1_real[
            :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
        ],
        self.w2[1],
    )
    + self.b2[1]
)

There are four of these einsum blocks (one for each corner of the spectral coefficient matrix). They explicitly perform a single MLP layer on the spectral coefficients (both the real and imaginary parts); note the self.w2 and self.b2 terms.
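To see concretely what the einsum pattern does, here is a small self-contained check (with hypothetical shapes, not the actual model dimensions) that "...bi,bio->...bo" performs a per-block matrix multiply, and that the sum of the two einsums above is exactly the imaginary part of a complex block-diagonal matmul:

```python
import torch

# Hypothetical shapes: (batch, h, w_modes, block, channel)
blocks, in_ch, out_ch = 4, 8, 8
o1_real = torch.randn(2, 16, 9, blocks, in_ch)
o1_imag = torch.randn_like(o1_real)
w_real = torch.randn(blocks, in_ch, out_ch)  # plays the role of w2[0]
w_imag = torch.randn(blocks, in_ch, out_ch)  # plays the role of w2[1]

# Imaginary part of (o1_real + i*o1_imag) @ (w_real + i*w_imag),
# mirroring the sum of the two einsums in the snippet above
o2_imag = (torch.einsum("...bi,bio->...bo", o1_imag, w_real)
           + torch.einsum("...bi,bio->...bo", o1_real, w_imag))

# Cross-check against PyTorch's native complex einsum
o1 = torch.complex(o1_real, o1_imag)
w = torch.complex(w_real, w_imag)
ref = torch.einsum("...bi,bio->...bo", o1, w)
assert torch.allclose(o2_imag, ref.imag, atol=1e-4)
```

Each of the `blocks` groups of channels is mixed only by its own weight matrix, which is what makes the spectral MLP block diagonal.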

Does this clear things up?

Hi @ngeneva,

Thanks for your quick response!

So as I understand it, the einsum operations are purely linear, since there is no nonlinearity between the matrix multiplications. The AFNO paper, however, proposes to use an MLP:

def BlockMLP(x):
    x = MatMul(x, W_1) + b_1
    x = ReLU(x)
    return MatMul(x, W_2) + b_2

As you can see in the code snippet from my first message, however, the AFNO implementation applies an MLP only after the inverse Fourier transform. So I feel there is a missing nonlinearity in the AFNO2D module in the computation of the o2 quantities, and, conversely, that there shouldn’t be an MLP applied after the AFNO2D module.

Oh, I’m sorry, I realised I just missed the ReLU after the computation of the o1 quantities. My bad!


@alexandre.szenicer Thanks for looking into this. There are a couple of figures in the paper; attaching the most relevant part here. There is the MLP between the FFT and IFFT, which is the block-diagonal matrix multiply as in the first code snippet of your original post. However, there is also another MLP after the IFFT that is not block diagonal, and that’s what is in self.mlp in the second code snippet that you have. Hope this helps clear up some of the confusion.
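For reference, the post-IFFT MLP (self.mlp in the AFNO block) is an ordinary dense channel MLP. A hedged sketch, assuming the common transformer-style layout (the class name, hidden_ratio parameter, and GELU activation are assumptions, not the exact Modulus code):

```python
import torch
import torch.nn as nn

class ChannelMLP(nn.Module):
    """Hypothetical sketch of the dense (non-block-diagonal) MLP
    applied after the inverse FFT in each AFNO block."""

    def __init__(self, dim, hidden_ratio=4, drop=0.0):
        super().__init__()
        hidden = int(dim * hidden_ratio)
        self.fc1 = nn.Linear(dim, hidden)  # mixes all channels jointly
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # x: (batch, h, w, dim); Linear acts on the last dimension
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))
```

Unlike the spectral MLP, every output channel here can depend on every input channel, so the two MLPs play different roles.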

Hi @alexandre.szenicer

Would you be interested in speaking with the Modulus team to share more details on your use case? You can reach us directly at modulus-team@exchange.nvidia.com.

Thanks
Ram