bias = x
x = RFFT2(x)
x = x.reshape(b, h, w//2+1, k, d//k)
x = BlockMLP(x)
x = x.reshape(b, h, w//2+1, d)
x = SoftShrink(x)
x = IRFFT2(x)
return x + bias
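For concreteness, here is a minimal NumPy sketch of the pipeline above. This is not the actual implementation; the shapes (b, h, w, d), the per-part ReLU, and the soft-shrink threshold `lam` are assumptions for illustration.

```python
import numpy as np

def soft_shrink(x, lam=0.01):
    # soft-thresholding toward zero (applied separately to real/imag parts)
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)

def afno_mixing(x, w1, b1, w2, b2, k):
    """Sketch of the paper's pipeline for x of shape (b, h, w, d).

    w1, w2: (k, d//k, d//k) block-diagonal weights; b1, b2: (k, d//k) biases.
    """
    b, h, w, d = x.shape
    bias = x
    xf = np.fft.rfft2(x, axes=(1, 2))              # (b, h, w//2+1, d), complex
    xf = xf.reshape(b, h, w // 2 + 1, k, d // k)   # split channels into k blocks

    def block_mlp(z):
        # block-diagonal two-layer MLP on the spectral coefficients
        z = np.einsum('bhwki,kio->bhwko', z, w1) + b1
        # ReLU applied to real and imaginary parts separately (assumption)
        z = np.maximum(z.real, 0.0) + 1j * np.maximum(z.imag, 0.0)
        return np.einsum('bhwki,kio->bhwko', z, w2) + b2

    xf = block_mlp(xf)
    xf = xf.reshape(b, h, w // 2 + 1, d)
    xf = soft_shrink(xf.real) + 1j * soft_shrink(xf.imag)
    x = np.fft.irfft2(xf, s=(h, w), axes=(1, 2))
    return x + bias
```

With all weights and biases zero, the spectral branch vanishes and the residual passes through unchanged, which is a quick sanity check on the skip connection.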

such that the MLP happens between the FFT and IFFT.
However, in the codebase, it’s

residual = x
x = self.norm1(x)
x = self.filter(x)
if self.double_skip:
    x = x + residual
    residual = x
x = self.norm2(x)
x = self.mlp(x)
x = self.drop_path(x)
x = x + residual

such that the MLP happens after the IFFT. This seems like a pretty important difference. Am I missing something?

Thanks for your interest. The section of code you’re looking at is a block of AFNO inside FourCastNet, which is what is implemented in Modulus as well as in the FCN repo you linked.

To see the Fourier convolution, we need to look at the AFNO2D module. Inside the forward of this torch module, you will see the FFT transforms. In between the rfft2 and irfft2 there are some reshapes as well as einsum ops, which look like the following:

There are four of them (one for each corner of the spectral coefficient matrix). These explicitly perform a single MLP layer on the spectral coefficients (both the real and imaginary parts); note the self.w2 and self.b2 terms.
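To illustrate the pattern, here is a hedged sketch of one such pair of einsums (the names o1, w2, b2 follow the description above, but this is not the exact code): a complex block-diagonal matrix multiply written out in terms of real and imaginary parts.

```python
import numpy as np

def complex_block_matmul(o1_real, o1_imag, w2, b2):
    """Illustrative sketch: complex block-diagonal matmul via real/imag einsums.

    o1_real, o1_imag: (..., k, d_k); w2: (2, k, d_k, d_k); b2: (2, k, d_k),
    where index 0 holds the real parts of the weights and 1 the imaginary parts.
    """
    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    o2_real = (np.einsum('...ki,kio->...ko', o1_real, w2[0])
               - np.einsum('...ki,kio->...ko', o1_imag, w2[1]) + b2[0])
    o2_imag = (np.einsum('...ki,kio->...ko', o1_imag, w2[0])
               + np.einsum('...ki,kio->...ko', o1_real, w2[1]) + b2[1])
    return o2_real, o2_imag
```

The four einsums correspond to the four cross terms of the complex product; splitting them this way lets the whole operation run on real tensors.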

So as I understand it, the operations with einsum only perform linear operations, as there is no nonlinearity in between the matrix multiplications. The AFNO paper, however, proposes to use an MLP, such that

def BlockMLP(x):
    x = MatMul(x, W_1) + b_1
    x = ReLU(x)
    return MatMul(x, W_2) + b_2

As you can see in the code snippet from my first message, however, the AFNO implementation uses the MLP only after the inverse Fourier transform. Therefore I feel like there is a missing nonlinearity in the AFNO2D module in the computation of the o2 quantities. What’s more, there shouldn’t be an MLP applied after the AFNO2D module.

@alexandre.szenicer Thanks for looking into this. There are a couple of figures in the paper. Attaching the most relevant part here. There is the MLP between the FFT and IFFT, which is the block-diagonal matrix multiply as in the first code snippet of your original post. However, there is also another MLP after the IFFT that is not block-diagonal, and that’s what is in self.mlp in the second code snippet that you have. Hope this helps clear up some of the confusion.
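Putting the two together: the block applies the filter (with the block-diagonal spectral MLP inside) and then an ordinary channel MLP after the inverse FFT. A minimal sketch of that control flow, with the norms omitted for brevity and the filter/MLP passed in as functions (an assumption for illustration, not the actual module):

```python
import numpy as np

def afno_block(x, filter_fn, mlp_fn, double_skip=True):
    """Sketch mirroring the second snippet above (norms omitted)."""
    residual = x
    x = filter_fn(x)      # spectral mixing: FFT -> block-diagonal MLP -> IFFT
    if double_skip:
        x = x + residual
        residual = x
    x = mlp_fn(x)         # ordinary (non-block-diagonal) channel MLP
    return x + residual
```

So there are two MLPs per block, one on spectral coefficients and one in physical space, each wrapped by its own skip connection when double_skip is set.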

Would you be interested in speaking to the Modulus team to share more details on your use case? You can reach out to us directly at modulus-team@exchange.nvidia.com.