I am trying to export a PyTorch model that uses `PackedSequence` to ONNX, and I get the following error:
RuntimeError: ONNX export failed: Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.
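Sorry if this is a silly question, but how can I use these operations in pairs? In the model below I already call `pack_padded_sequence` and then `pad_packed_sequence`, with my `feature_bn` layer applied to the packed data in between.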
Model code (the `forward` method that follows belongs to the `feature_bn` module):
# Pack the padded, batch-first features so that feature_bn only sees the valid
# (non-padding) time steps, normalise them, then pad back to a batch-first tensor.
# NOTE(review): this pack -> feature_bn -> pad chain is what the ONNX exporter
# rejects (see the RuntimeError quoted in the question). Presumably the exporter
# can only eliminate a pack/pad pair when nothing but an RNN consumes the packed
# data in between, whereas here feature_bn reads the packed data directly.
# TODO confirm against the installed torch version.
packed_input = nn.utils.rnn.pack_padded_sequence(
    output_features.to(self.device),
    origin_nrof_trans,
    batch_first=True,
    enforce_sorted=False,
)
# feature_bn returns a new PackedSequence that carries the same packing metadata.
packed_normed = self.feature_bn(packed_input)
bn_output_features, _ = nn.utils.rnn.pad_packed_sequence(packed_normed, batch_first=True)
def forward(self, x):
    """Apply the parent layer's ``forward`` to the flat data of a PackedSequence.

    Args:
        x: a ``torch.nn.utils.rnn.PackedSequence``. Its ``data`` tensor is fed
            to ``super().forward``. This presumably assumes a parent layer that
            works row-wise on a (total_steps, features) tensor, such as
            ``BatchNorm1d``; TODO confirm the actual base class.

    Returns:
        A new ``PackedSequence`` whose ``data`` is ``super().forward(x.data)``.
        Its ``batch_sizes``, ``sorted_indices`` and ``unsorted_indices`` are
        copied unchanged from ``x``.
    """
    # Pass unsorted_indices through explicitly. If it is omitted, PackedSequence
    # recomputes it from sorted_indices (invert_permutation, a scatter op).
    return torch.nn.utils.rnn.PackedSequence(
        data=super().forward(x.data),
        batch_sizes=x.batch_sizes,
        sorted_indices=x.sorted_indices,
        unsorted_indices=x.unsorted_indices,
    )
Export code:
# Export self.base_rnn to ONNX with two named inputs and one named output.
# If a variable batch size is needed, pass dynamic_axes whose keys are exactly
# these names, e.g. {'input_data': {0: 'batch_size'}, 'output': {0: 'batch_size'}}.
# A key such as 'input' matches no declared input name.
onnx_export_options = dict(
    export_params=True,
    opset_version=10,
    do_constant_folding=True,
    input_names=["input_data", "origin_nrof_trans"],
    output_names=["output"],
)
torch.onnx.export(
    model=self.base_rnn,
    args=(features, nrof_trans),
    f="ct4001.onnx",
    **onnx_export_options,
)
I also tried wrapping `pad_packed_sequence` in a helper that unpacks the `PackedSequence` right away, but that doesn't work either:
def pad_sequence(f, *args, batch_first=False):
    """Call ``f(*args, batch_first=...)`` and also return the packed metadata.

    Intended for ``f = nn.utils.rnn.pad_packed_sequence`` with ``args[0]`` a
    ``PackedSequence``. The name shadows ``torch.nn.utils.rnn.pad_sequence``
    if both are imported into the same namespace.

    Args:
        f: function called as ``f(*args, batch_first=batch_first)``. Its
            result is indexed with ``[0]``.
        *args: positional arguments for ``f``. ``args[0]`` must expose
            ``.data``, ``.batch_sizes`` and ``.sorted_indices``.
        batch_first: forwarded to ``f``.

    Returns:
        A 4-tuple ``(f(...)[0], args[0].data, args[0].batch_sizes,
        args[0].sorted_indices)``.

    Raises:
        IndexError: if no positional ``args`` are given.

    Note:
        ``f`` is still called on the packed sequence, so this helper does not
        remove the ``pad_packed_sequence`` call from a traced/exported graph.
    """
    packed = args[0]
    # Read the packing metadata before calling f, as the original wrapper did.
    data, batch_sizes, sorted_indices = (
        packed.data,
        packed.batch_sizes,
        packed.sorted_indices,
    )
    padded = f(*args, batch_first=batch_first)[0]
    return padded, data, batch_sizes, sorted_indices
# Pack the padded features, then run them back through the pad_sequence helper
# so that the raw packed data and its metadata can be passed to feature_bn as
# plain tensors.
# NOTE(review): the helper is called without batch_first, so `data` (the
# re-padded, un-normalised features) is time-major. It is not used in the
# visible code.
packed_features = nn.utils.rnn.pack_padded_sequence(
    output_features.to(self.device),
    origin_nrof_trans,
    batch_first=True,
    enforce_sorted=False,
)
data, tensor, batch_sz, sorted_inds = pad_sequence(
    nn.utils.rnn.pad_packed_sequence,
    packed_features,
)
# This version of feature_bn.forward returns a padded tensor, not a PackedSequence.
bn_output_features = self.feature_bn(tensor, batch_sz, sorted_inds)
def forward(self, data, batch_sz, sorted_inds, unsorted_inds=None):
    """Transform flat packed data and return it as a padded, batch-first tensor.

    Args:
        data: the ``.data`` tensor of a PackedSequence. It is fed to
            ``super().forward``.
        batch_sz: the matching ``.batch_sizes`` tensor.
        sorted_inds: the matching ``.sorted_indices`` tensor.
        unsorted_inds: optional matching ``.unsorted_indices``. When omitted,
            ``PackedSequence`` derives it from ``sorted_inds``.

    Returns:
        The padded tensor from ``pad_packed_sequence(..., batch_first=True)``.
        The lengths it also returns are discarded.
    """
    packed = nn.utils.rnn.PackedSequence(
        data=super().forward(data),
        batch_sizes=batch_sz,
        sorted_indices=sorted_inds,
        unsorted_indices=unsorted_inds,
    )
    # NOTE(review): this pad_packed_sequence has no matching pack_padded_sequence
    # inside this module, which is presumably why ONNX export still fails with
    # this variant. TODO confirm.
    padded, _ = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)
    return padded