Thanks for your reply.
The performance I measured is poor, especially for the backward pass.
I added timers as shown below:
Forward only (IS_FORWARD == 1):
double start = seconds();
cudnnMultiHeadAttnForward(...);
cudaDeviceSynchronize();
double stop = seconds();
Forward plus backward (IS_FORWARD == 0):
double start = seconds();
cudnnMultiHeadAttnForward(...);
cudnnMultiHeadAttnBackwardData(...);
cudnnMultiHeadAttnBackwardWeights(...);
cudaDeviceSynchronize();
double stop = seconds();
In both cases:
double duration = stop - start;
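To find out which of the three calls dominates, each one could be timed separately, after a few untimed warm-up runs. Below is a minimal host-side sketch of such a helper. The names `time_mean`, `work`, and `sync` are hypothetical and not part of cuDNN, and `seconds()` is reimplemented here with `std::chrono` because the original helper is not shown. In a GPU build, `sync` would wrap `cudaDeviceSynchronize()` and `work` would wrap a single cuDNN call.

```cpp
#include <cassert>
#include <chrono>
#include <functional>

// Wall-clock time in seconds from a monotonic clock (stand-in for the
// seconds() helper used in the post, whose implementation is not shown).
static double seconds() {
    using clock = std::chrono::steady_clock;
    return std::chrono::duration<double>(clock::now().time_since_epoch()).count();
}

// Runs `work` `warmup` times untimed, then `iters` times timed, and returns
// the mean time per timed iteration in seconds.
// `sync` is called before each clock read so that queued asynchronous work is
// finished; on a GPU it would call cudaDeviceSynchronize().
static double time_mean(const std::function<void()>& work,
                        const std::function<void()>& sync,
                        int warmup, int iters) {
    assert(warmup >= 0 && iters > 0);
    for (int i = 0; i < warmup; ++i) work();  // warm-up, excluded from timing
    sync();
    const double start = seconds();
    for (int i = 0; i < iters; ++i) work();
    sync();
    const double stop = seconds();
    return (stop - start) / iters;
}
```

Calling `time_mean` once per cuDNN call (forward, backward data, backward weights) would show whether the 260 ms comes from one call or is spread across both backward calls.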
The result:
IS_FORWARD 1, Elapsed time = 0.000691891 sec
IS_FORWARD 0, Elapsed time = 0.262995 sec
IS_FORWARD 1, Elapsed time = 0.000668049 sec
IS_FORWARD 0, Elapsed time = 0.262635 sec
IS_FORWARD 1, Elapsed time = 0.000663996 sec
IS_FORWARD 0, Elapsed time = 0.263343 sec
IS_FORWARD 1, Elapsed time = 0.000674963 sec
IS_FORWARD 0, Elapsed time = 0.263066 sec
IS_FORWARD 1, Elapsed time = 0.000671148 sec
IS_FORWARD 0, Elapsed time = 0.262862 sec
IS_FORWARD 1, Elapsed time = 0.000673056 sec
IS_FORWARD 0, Elapsed time = 0.262764 sec
IS_FORWARD 1, Elapsed time = 0.000664949 sec
IS_FORWARD 0, Elapsed time = 0.262851 sec
IS_FORWARD 1, Elapsed time = 0.000669003 sec
IS_FORWARD 0, Elapsed time = 0.262641 sec
IS_FORWARD 1, Elapsed time = 0.000679016 sec
IS_FORWARD 0, Elapsed time = 0.262704 sec
IS_FORWARD 1, Elapsed time = 0.000658989 sec
IS_FORWARD 0, Elapsed time = 0.262794 sec
IS_FORWARD 1, Elapsed time = 0.000673056 sec
IS_FORWARD 0, Elapsed time = 0.263157 sec
IS_FORWARD 1, Elapsed time = 0.000689983 sec
IS_FORWARD 0, Elapsed time = 0.262993 sec
IS_FORWARD 1, Elapsed time = 0.000663996 sec
IS_FORWARD 0, Elapsed time = 0.262549 sec
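The IS_FORWARD == 0 run, which covers the forward pass plus both backward calls, takes more than 260 ms, compared with about 0.67 ms for the forward pass alone.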
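Because the IS_FORWARD == 0 measurement also includes the forward call, the backward-only cost can be estimated by subtracting the mean forward time. A small sketch of that arithmetic, using the timings from the log above (the helper names `mean` and `summarize` are made up for illustration):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Arithmetic mean of a list of timings in seconds.
static double mean(const std::vector<double>& xs) {
    assert(!xs.empty());
    double sum = 0.0;
    for (double x : xs) sum += x;
    return sum / static_cast<double>(xs.size());
}

// Timings copied from the log above, in seconds.
static const std::vector<double> kForward = {
    0.000691891, 0.000668049, 0.000663996, 0.000674963, 0.000671148,
    0.000673056, 0.000664949, 0.000669003, 0.000679016, 0.000658989,
    0.000673056, 0.000689983, 0.000663996};
static const std::vector<double> kForwardBackward = {
    0.262995, 0.262635, 0.263343, 0.263066, 0.262862, 0.262764, 0.262851,
    0.262641, 0.262704, 0.262794, 0.263157, 0.262993, 0.262549};

// Prints the mean forward time, the estimated backward-only time
// (forward+backward minus forward), and the ratio between the two.
static void summarize() {
    const double fwd = mean(kForward);
    const double both = mean(kForwardBackward);
    const double bwdOnly = both - fwd;
    std::printf("forward        = %.6f sec\n", fwd);
    std::printf("backward only  = %.6f sec\n", bwdOnly);
    std::printf("backward / fwd = %.1fx\n", bwdOnly / fwd);
}
```

With these numbers the backward calls together come out a few hundred times slower than the forward call.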
Config:
#### attnDataType = 0 (FP32)
#### attnNumHeads = 16
#### attnBatchSize = 1
#### attnBeamSize = 1
#### attnSmScaler = 1.0000e+00
#### attnDropoutRate = 0.0000
#### attnQsize = 1024
#### attnKsize = 1024
#### attnVsize = 1024
#### attnProjQsize = 64
#### attnProjKsize = 64
#### attnProjVsize = 64
#### attnProjOsize = 1024
#### attnSeqLenQ = 384
#### attnSeqLenK = 384
#### attnDataLayout = 0 (T,N,B,V)
#### attnResLink = 0
#### attnSweep = 0
#### attnRandGeom = 0
#### attnRandSeed = 1234
#### attnFileDump = 0
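As a sanity check on the timings, the forward arithmetic cost for this configuration can be estimated from the sizes above. This is a rough sketch only. It assumes the standard multi-head attention formulation (per-head Q/K/V projections, Q*K^T scores, weighted sum of V, output projection), counts each multiply-add as 2 FLOPs, and ignores softmax, scaling, and biases. The struct `AttnConfig` and function `forward_flops` are hypothetical names, not cuDNN API.

```cpp
#include <cassert>
#include <cstdio>

// Subset of the configuration printed above.
struct AttnConfig {
    long long numHeads, batchSize, beamSize;
    long long qSize, kSize, vSize;
    long long projQSize, projKSize, projVSize, projOSize;
    long long seqLenQ, seqLenK;
};

static const AttnConfig kConfig = {16, 1, 1, 1024, 1024, 1024,
                                   64, 64, 64, 1024, 384, 384};

// Rough forward FLOP count under the assumptions stated in the lead-in.
static double forward_flops(const AttnConfig& c) {
    assert(c.projQSize == c.projKSize);  // required for the Q*K^T dot product
    const double qTokens = double(c.batchSize * c.beamSize * c.seqLenQ);
    const double kTokens = double(c.batchSize * c.seqLenK);
    const double h = double(c.numHeads);
    // Input projections: one (size x projSize) matrix per head.
    const double projQ = 2.0 * qTokens * c.qSize * h * c.projQSize;
    const double projK = 2.0 * kTokens * c.kSize * h * c.projKSize;
    const double projV = 2.0 * kTokens * c.vSize * h * c.projVSize;
    // Attention scores (Q*K^T) and weighted sum of V, per head.
    const double scores = 2.0 * qTokens * c.seqLenK * h * c.projQSize;
    const double weighted = 2.0 * qTokens * c.seqLenK * h * c.projVSize;
    // Output projection from the concatenated heads to projOSize.
    const double projO = 2.0 * qTokens * h * c.projVSize * c.projOSize;
    return projQ + projK + projV + scores + weighted + projO;
}

// Prints the estimate in GFLOP.
static void report() {
    std::printf("forward ~= %.2f GFLOP\n", forward_flops(kConfig) / 1e9);
}
```

For this configuration the estimate is roughly 3.8 GFLOP for the forward pass. A backward pass usually costs a small multiple of the forward pass (a rule of thumb, not a cuDNN-documented figure), so a gap of several hundred times does not look like it can be explained by arithmetic cost alone.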
Any suggestions for improving the performance?