I’m trying to write a custom PyTorch FasterRCNN based on an existing torchvision model, fasterrcnn_resnet50_fpn, but I’m stuck on how to write the forward pass correctly. Following this recommendation, I’m basing my forward pass on the GeneralizedRCNN forward pass, because I’d like to eventually modify the loss function. However, even without any loss-function modifications, my forward pass doesn’t behave as expected, and I’m not sure how to proceed.
Here is my code:
from collections import OrderedDict

import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.image_list import ImageList


class CustomFasterRCNNResNet50FPN(nn.Module):
    def __init__(self, num_classes, **kwargs):
        super().__init__()
        # Load the pre-trained fasterrcnn_resnet50_fpn model
        self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
        # Replace the classifier with a new one for the desired number of classes
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    def forward(self, images, targets=None):
        # Convert the list of images to a tensor (all images must have the same shape)
        images = torch.stack(images)
        # Create an ImageList object from the images tensor
        image_list = ImageList(images, [(img.shape[-2], img.shape[-1]) for img in images])
        # Pass the images through the model
        features = self.model.backbone(images)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])
        proposals, proposal_losses = self.model.rpn(image_list, features, targets)
        detections, detector_losses = self.model.roi_heads(features, proposals, image_list.image_sizes, targets)
        print("Detections:", detections)
        print("Detector Losses:", detector_losses)
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
When I run this, the printout shows that detections is an empty list, which it shouldn’t be. I know the problem isn’t my images or my targets, because I can run them directly through an unmodified fasterrcnn_resnet50_fpn and get detections. I also know it’s not the initialization, because I can test that with a simpler forward pass, which works but wouldn’t let me modify the loss function later.
Thank you in advance for any and all help with this!