How to transfer LSTM caffemodel to TensorRT weights

The formats of the LSTM parameters differ between Caffe and TensorRT. Therefore, to use the LSTM weights from a caffemodel, we must convert their format and place the converted weights into the TensorRT layer when we construct the TensorRT engine.

Follow the steps to transfer it:

  1. The LSTM format in the Caffe model is
    [W_xi, W_xf, W_xo, W_xc]
    [B_i, B_f, B_o, B_c]
    [W_hi, W_hf, W_ho, W_hc],

while the LSTM format in TRT model is
[W_xf, W_xi, W_xc, W_xo, W_hf, W_hi, W_hc, W_ho]
[B_xf, B_xi, B_xc, B_xo, B_hf, B_hi, B_hc, B_ho]

PS: “W” means the weight and “B” means the bias. “x” refers to the input (e.g. an image input) and “h” to the hidden state. “i” is the input gate, “f” the forget gate, “o” the output gate, and “c” the cell gate.

Here is the function I used for transferring:

//C++
//caffe format [W_xi, W_xf, W_xo, W_xc] [B_i, B_f, B_o, B_c] [W_hi, W_hf, W_ho, W_hc]
//TRT format [W_xf, W_xi, W_xc, W_xo, W_hf, W_hi, W_hc, W_ho]
//           [B_xf, B_xi, B_xc, B_xo, B_hf, B_hi, B_hc, B_ho]
//Convert the LSTM parameters of the Caffe layer `layer_name` in `lstm` into
//the interleaved layout TensorRT expects, allocating the destination buffers
//and publishing them through `weights` / `bias`.
//
//Caffe stores three blobs:  blobs[0] = [W_xi, W_xf, W_xo, W_xc]
//                           blobs[1] = [B_i,  B_f,  B_o,  B_c ]
//                           blobs[2] = [W_hi, W_hf, W_ho, W_hc]
//TensorRT expects:          weights  = [W_xf, W_xi, W_xc, W_xo, W_hf, W_hi, W_hc, W_ho]
//                           bias     = [B_xf, B_xi, B_xc, B_xo, B_hf, B_hi, B_hc, B_ho]
//
//NOTE(review): the buffers are malloc'ed here and ownership passes to the
//caller via `weights.values` / `bias.values`; the caller is presumably
//responsible for freeing them after the engine is built — confirm.
void convertCaffe2TRT_LSTM(shared_ptr<Net<float> > lstm, string layer_name,
                          Weights& weights, Weights& bias, int data_size,
                          int hidden_size){
  const shared_ptr<Layer<float> > lstm_layer = lstm->layer_by_name(layer_name);
  vector<shared_ptr<Blob<float> > >& caffe_weights = lstm_layer->blobs();

  //Do the size arithmetic in long from the start so large
  //data_size*hidden_size products cannot overflow int before widening.
  const long x_gate_weight_num = static_cast<long>(data_size) * hidden_size;
  const long h_gate_weight_num = static_cast<long>(hidden_size) * hidden_size;
  const long weight_cnt = 4 * (x_gate_weight_num + h_gate_weight_num);
  const long bias_cnt = 8L * hidden_size;

  //static_cast (not reinterpret_cast) is the correct named cast for malloc's
  //void* result.
  float* weight_val = static_cast<float*>(malloc(sizeof(float) * weight_cnt));
  float* bias_val = static_cast<float*>(malloc(sizeof(float) * bias_cnt));
  if (weight_val == NULL || bias_val == NULL) {
    //Allocation failed: release whichever half succeeded and publish an
    //empty result instead of writing through a null pointer.
    free(weight_val);
    free(bias_val);
    weights.values = NULL;
    weights.count = 0;
    bias.values = NULL;
    bias.count = 0;
    return;
  }
  //TensorRT wants 8 bias vectors but Caffe only provides the 4 input-side
  //ones; zero the whole buffer so the hidden-side half stays 0.
  memset(bias_val, 0, sizeof(float) * bias_cnt);

  weights.values = weight_val;
  weights.count = weight_cnt;
  bias.values = bias_val;
  bias.count = bias_cnt;

  //Caffe's gate order inside each blob is [i, f, o, c]; TensorRT wants
  //[f, i, c, o] — i.e. source chunk indices {1, 0, 3, 2}.
  static const int gate_order[4] = {1, 0, 3, 2};

  //Input-to-hidden weights (W_x*) from blobs[0].
  float* weight_dst = weight_val;
  for (int g = 0; g < 4; ++g) {
    memcpy(weight_dst,
           caffe_weights[0]->cpu_data() + gate_order[g] * x_gate_weight_num,
           sizeof(float) * x_gate_weight_num);
    weight_dst += x_gate_weight_num;
  }
  //Hidden-to-hidden weights (W_h*) from blobs[2].
  for (int g = 0; g < 4; ++g) {
    memcpy(weight_dst,
           caffe_weights[2]->cpu_data() + gate_order[g] * h_gate_weight_num,
           sizeof(float) * h_gate_weight_num);
    weight_dst += h_gate_weight_num;
  }
  //Input-side biases from blobs[1]; the remaining 4*hidden_size floats of
  //bias_val keep the zeros from memset above.
  float* bias_dst = bias_val;
  for (int g = 0; g < 4; ++g) {
    memcpy(bias_dst,
           caffe_weights[1]->cpu_data() + gate_order[g] * hidden_size,
           sizeof(float) * hidden_size);
    bias_dst += hidden_size;
  }
}
  1. Set the two Weights objects (lstm_weights and lstm_bias) on the TRT LSTM layer and initialize the cell input and hidden input with constant 0.
//C++
//Create the TensorRT LSTM layer (1 layer, unidirectional, linear input mode)
//and attach the converted weight/bias buffers produced above.
auto lstm = network->addRNN(*dataIn, 1, HIDDEN_SIZE, SEQ_SIZE,
          RNNOperation::kLSTM, RNNInputMode::kLINEAR, RNNDirection::kUNIDIRECTION,
          lstm_weights, lstm_bias);
//Feed in the initial hidden/cell state tensors (zero-filled per the text
//above — presumably created earlier as constant-0 network inputs; confirm).
lstm->setHiddenState(*hiddenIn);
lstm->setCellState(*cellIn);