# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.

from collections import OrderedDict

import torch.nn.functional as F
import torchvision
from torch import nn

# pvt_v2_b2_2 is not used in this module (an earlier variant of build_network built the
# model with it); the import is kept so nothing that re-imports it from here breaks.
from backbone.pvt_v2 import pvt_v2_b2_2, pvt_v2_b2  # noqa: F401

# Default PVT-v2-b2 checkpoint location, relative to the current working directory.
_PVT_V2_B2_WEIGHTS = "./backbone/pvt_v2_b2.pth"


def _last_conv_out_channels(module):
    """Return ``out_channels`` of the last ``nn.Conv2d`` found inside ``module``.

    ``module.modules()`` walks sub-modules depth-first in registration order, so for a
    torchvision ResNet stage (``layerN`` built from BasicBlock/Bottleneck) the last
    convolution belongs to the stage's last block and its ``out_channels`` is the stage's
    output channel count.

    Raises:
        ValueError: if ``module`` contains no ``nn.Conv2d``.
    """
    convs = [m for m in module.modules() if isinstance(m, nn.Conv2d)]
    if not convs:
        raise ValueError(f"{type(module).__name__} contains no nn.Conv2d layer")
    return convs[-1].out_channels


class Backbone(nn.Sequential):
    """ResNet stem plus ``layer1``-``layer3`` (res2-res4) as one sequential trunk.

    ``forward`` returns an ``OrderedDict`` with a single entry keyed ``"feat_res4"``.
    """

    def __init__(self, resnet):
        super(Backbone, self).__init__(
            OrderedDict(
                [
                    ["conv1", resnet.conv1],
                    ["bn1", resnet.bn1],
                    ["relu", resnet.relu],
                    ["maxpool", resnet.maxpool],
                    ["layer1", resnet.layer1],  # res2
                    ["layer2", resnet.layer2],  # res3
                    ["layer3", resnet.layer3],  # res4
                ]
            )
        )
        # Derived instead of hard-coded so non-Bottleneck ResNets report the right width:
        # 1024 for resnet50/101/152, 256 for resnet18/34.
        self.out_channels = _last_conv_out_channels(resnet.layer3)
        self.out_feat_key = "feat_res4"

    def forward(self, x):
        # using the forward method from nn.Sequential
        feat = super(Backbone, self).forward(x)
        return OrderedDict([[self.out_feat_key, feat]])


class PVTv2Backbone(pvt_v2_b2):
    """PVT-v2-b2 backbone exposing a single feature map keyed ``"feat_pvtv23"``.

    Args:
        pretrained_path: forwarded unchanged as the ``pretrained`` argument of
            ``pvt_v2_b2``; how an empty string is treated depends on that class —
            TODO confirm against backbone/pvt_v2.py.
    """

    def __init__(self, pretrained_path=""):
        super(PVTv2Backbone, self).__init__(pretrained=pretrained_path)
        # NOTE(review): 512 presumably matches the channel width of the stage selected in
        # forward (index 3) — confirm against backbone/pvt_v2.py.
        self.out_channels = 512
        self.out_feat_key = "feat_pvtv23"

    def forward(self, x):
        # Assumes the parent forward returns an indexable sequence of per-stage feature
        # maps; only the element at index 3 is kept.
        feats = super(PVTv2Backbone, self).forward(x)
        return OrderedDict([[self.out_feat_key, feats[3]]])


class Res5Head(nn.Sequential):
    """ResNet ``layer4`` (res5) head that also pools its own input.

    ``forward`` returns an ``OrderedDict`` with ``"feat_res4"`` (the input, adaptive
    max-pooled to 1x1) and ``"feat_res5"`` (the ``layer4`` output, adaptive max-pooled
    to 1x1).
    """

    def __init__(self, resnet):
        super(Res5Head, self).__init__(OrderedDict([["layer4", resnet.layer4]]))  # res5
        # [input width (layer3 output), output width (layer4 output)];
        # [1024, 2048] for Bottleneck ResNets such as resnet50.
        self.out_channels = [
            _last_conv_out_channels(resnet.layer3),
            _last_conv_out_channels(resnet.layer4),
        ]

    def forward(self, x):
        feat = super(Res5Head, self).forward(x)
        x = F.adaptive_max_pool2d(x, 1)
        feat = F.adaptive_max_pool2d(feat, 1)
        return OrderedDict([["feat_res4", x], ["feat_res5", feat]])


def build_resnet(name="resnet50", pretrained=True):
    """Build the torchvision ResNet ``name`` and wrap its stem and layer1-layer3 in ``Backbone``.

    The parameters of ``conv1`` and ``bn1`` get ``requires_grad=False``. This function does
    not put ``bn1`` into eval mode, so its running statistics are not frozen here.
    """
    resnet = torchvision.models.resnet.__dict__[name](pretrained=pretrained)

    # freeze the stem's parameters
    resnet.conv1.weight.requires_grad_(False)
    resnet.bn1.weight.requires_grad_(False)
    resnet.bn1.bias.requires_grad_(False)

    return Backbone(resnet)


def build_network(name="resnet50", pretrained=True, pvt_pretrained_path=_PVT_V2_B2_WEIGHTS):
    """Build the backbone selected by ``name``.

    Args:
        name: any name starting with ``"resnet"`` (a torchvision ResNet constructor such as
            ``"resnet50"`` or ``"resnet101"``) builds a ``Backbone`` via ``build_resnet``;
            any other value builds a ``PVTv2Backbone``.
        pretrained: forwarded to the torchvision constructor on the ResNet branch; not used
            on the PVT-v2 branch.
        pvt_pretrained_path: checkpoint path handed to ``PVTv2Backbone``.

    Returns:
        A ``Backbone`` or a ``PVTv2Backbone``.
    """
    # Match the whole ResNet family, not just "resnet50", so other depths are not
    # silently mapped to the PVT-v2 backbone.
    if name.startswith("resnet"):
        return build_resnet(name, pretrained)

    # Use the PVT-v2-b2 network.
    # NOTE(review): `pretrained` is ignored on this branch (unchanged behavior); how
    # PVTv2Backbone would express "no pretrained weights" depends on backbone/pvt_v2.py.
    return PVTv2Backbone(pretrained_path=pvt_pretrained_path)