Example-Research-COAT/Code/Python/models/resnet.py

72 lines
2.5 KiB
Python

# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
from collections import OrderedDict
from backbone.pvt_v2 import pvt_v2_b2_2,pvt_v2_b2
import torch.nn.functional as F
import torchvision
from torch import nn
class Backbone(nn.Sequential):
    """Truncated ResNet trunk: the stem followed by layer1..layer3 (res2..res4).

    The stages are taken by reference from ``resnet`` and registered in order,
    so ``nn.Sequential.forward`` runs them one after another.

    Attributes:
        out_channels: channel count advertised for the output feature map
            (1024, which matches layer3 of bottleneck ResNets such as resnet50).
        out_feat_key: key under which ``forward`` returns the feature map.
    """

    def __init__(self, resnet):
        # Stage names double as submodule names; order defines execution order.
        stage_names = (
            "conv1",
            "bn1",
            "relu",
            "maxpool",
            "layer1",  # res2
            "layer2",  # res3
            "layer3",  # res4
        )
        stages = OrderedDict((stage, getattr(resnet, stage)) for stage in stage_names)
        super().__init__(stages)
        self.out_channels = 1024
        self.out_feat_key = "feat_res4"

    def forward(self, x):
        """Run all stages on ``x`` and return ``OrderedDict({out_feat_key: features})``."""
        res4 = super().forward(x)
        return OrderedDict([(self.out_feat_key, res4)])
class PVTv2Backbone(pvt_v2_b2):
    """PVTv2-b2 wrapper that exposes a single feature map under ``out_feat_key``.

    Args:
        pretrained_path: forwarded to ``pvt_v2_b2`` as its ``pretrained``
            argument. NOTE(review): presumably "" means "no checkpoint" —
            confirm against ``backbone.pvt_v2``.
    """

    def __init__(self, pretrained_path=""):
        super().__init__(pretrained=pretrained_path)
        self.out_feat_key = "feat_pvtv23"
        # Advertised channel count of the returned map; assumed to be the
        # last PVTv2-b2 stage's embedding width — TODO confirm.
        self.out_channels = 512

    def forward(self, x):
        """Return ``OrderedDict({out_feat_key: stage_outputs[3]})``."""
        stage_outputs = super().forward(x)
        # Index 3 is presumably the fourth (deepest) stage output — confirm
        # that pvt_v2_b2.forward returns a per-stage sequence.
        deepest = stage_outputs[3]
        return OrderedDict([(self.out_feat_key, deepest)])
class Res5Head(nn.Sequential):
    """ResNet layer4 (res5) head that returns globally max-pooled res4 and res5 features.

    Attributes:
        out_channels: ``[res4_channels, res5_channels]`` for the two outputs.
    """

    def __init__(self, resnet):
        stages = OrderedDict()
        stages["layer4"] = resnet.layer4  # res5
        super().__init__(stages)
        self.out_channels = [1024, 2048]

    def forward(self, x):
        """Apply layer4 to ``x`` and max-pool both ``x`` and the result to 1x1.

        Returns:
            OrderedDict with ``"feat_res4"`` (pooled input) followed by
            ``"feat_res5"`` (pooled layer4 output).
        """
        res5 = super().forward(x)
        pooled = OrderedDict()
        pooled["feat_res4"] = F.adaptive_max_pool2d(x, 1)
        pooled["feat_res5"] = F.adaptive_max_pool2d(res5, 1)
        return pooled
def build_resnet(name="resnet50", pretrained=True):
    """Build a torchvision ResNet and wrap its stem..layer3 in ``Backbone``.

    Args:
        name: attribute name of a builder in ``torchvision.models.resnet``
            (e.g. ``"resnet50"``); an unknown name raises ``KeyError``.
        pretrained: forwarded to the torchvision builder.

    Returns:
        A ``Backbone`` around the constructed network.
    """
    builder = torchvision.models.resnet.__dict__[name]
    resnet = builder(pretrained=pretrained)
    # Freeze the stem: conv1's weight and bn1's affine parameters. This only
    # disables gradients; it does not stop BatchNorm's running statistics
    # from updating in train mode.
    for frozen in (resnet.conv1.weight, resnet.bn1.weight, resnet.bn1.bias):
        frozen.requires_grad_(False)
    return Backbone(resnet)
def build_network(name="resnet50", pretrained=True, *, pvt_pretrained_path="./backbone/pvt_v2_b2.pth"):
    """Build the feature-extraction backbone selected by ``name``.

    Args:
        name: ``"resnet50"``, ``"resnet101"`` or ``"resnet152"`` builds a
            truncated torchvision ResNet via ``build_resnet``. Any other value
            builds a PVTv2-b2 backbone, which is the original fallback behaviour.
        pretrained: forwarded to ``build_resnet``. It is not consulted on the
            PVTv2 path, where the checkpoint is chosen by ``pvt_pretrained_path``.
        pvt_pretrained_path: checkpoint passed to ``PVTv2Backbone``. It defaults
            to the previously hard-coded path, which is relative and therefore
            resolved against the current working directory.

    Returns:
        The backbone module: ``Backbone`` for ResNets, ``PVTv2Backbone`` otherwise.
    """
    # Only bottleneck ResNets are routed to build_resnet. Backbone hard-codes
    # out_channels=1024, which matches their layer3 output but not the
    # 256-channel layer3 of resnet18/34.
    resnet_names = ("resnet50", "resnet101", "resnet152")
    if name in resnet_names:
        return build_resnet(name, pretrained)
    # Fallback: PVTv2-b2 network.
    return PVTv2Backbone(pretrained_path=pvt_pretrained_path)