# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
from collections import OrderedDict

import torch.nn.functional as F
import torchvision
from torch import nn


class Backbone(nn.Sequential):
    def __init__(self, resnet):
        super(Backbone, self).__init__(
            OrderedDict(
                [
                    ["conv1", resnet.conv1],
                    ["bn1", resnet.bn1],
                    ["relu", resnet.relu],
                    ["maxpool", resnet.maxpool],
                    ["layer1", resnet.layer1],  # res2
                    ["layer2", resnet.layer2],  # res3
                    ["layer3", resnet.layer3],  # res4
                ]
            )
        )
        self.out_channels = 1024

    def forward(self, x):
        # using the forward method from nn.Sequential
        feat = super(Backbone, self).forward(x)
        return OrderedDict([["feat_res4", feat]])


class Res5Head(nn.Sequential):
    def __init__(self, resnet):
        super(Res5Head, self).__init__(OrderedDict([["layer4", resnet.layer4]]))  # res5
        self.out_channels = [1024, 2048]

    def forward(self, x):
        feat = super(Res5Head, self).forward(x)
        x = F.adaptive_max_pool2d(x, 1)
        feat = F.adaptive_max_pool2d(feat, 1)
        return OrderedDict([["feat_res4", x], ["feat_res5", feat]])


def build_resnet(name="resnet50", pretrained=True):
    resnet = torchvision.models.resnet.__dict__[name](pretrained=pretrained)

    # freeze layers
    resnet.conv1.weight.requires_grad_(False)
    resnet.bn1.weight.requires_grad_(False)
    resnet.bn1.bias.requires_grad_(False)

    return Backbone(resnet), Res5Head(resnet)
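

# Illustrative sketch (not part of the original module): a minimal usage example
# showing how build_resnet() can be exercised on a dummy batch, assuming a
# torchvision version that still accepts the legacy `pretrained` keyword
# (newer releases prefer the `weights=` argument). Tensor shapes below assume
# ResNet-50 and a 256x256 input.
if __name__ == "__main__":
    import torch

    backbone, box_head = build_resnet(name="resnet50", pretrained=False)

    # Dummy batch of two 3-channel images.
    images = torch.randn(2, 3, 256, 256)

    # Backbone returns an OrderedDict with the stride-16 "feat_res4" map
    # (1024 channels for ResNet-50).
    feats = backbone(images)
    print({k: tuple(v.shape) for k, v in feats.items()})
    # e.g. {'feat_res4': (2, 1024, 16, 16)}

    # In the full detector, Res5Head consumes pooled RoI features; here the
    # res4 map is fed directly just to show the two pooled outputs it produces.
    head_feats = box_head(feats["feat_res4"])
    print({k: tuple(v.shape) for k, v in head_feats.items()})
    # e.g. {'feat_res4': (2, 1024, 1, 1), 'feat_res5': (2, 2048, 1, 1)}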