72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
# This file is part of COAT, and is distributed under the
|
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
|
|
|
from collections import OrderedDict
|
|
from backbone.pvt_v2 import pvt_v2_b2_2,pvt_v2_b2
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
from torch import nn
|
|
|
|
class Backbone(nn.Sequential):
    """ResNet trunk from the stem through ``layer3`` (res4), as a Sequential.

    The stem modules (conv1, bn1, relu, maxpool) and stages layer1..layer3 of
    the given ResNet are registered, in that order, as named children, so the
    inherited ``nn.Sequential.forward`` runs them back to back.

    Attributes:
        out_channels: channel count reported for the res4 feature map. It is
            hard-coded to 1024. NOTE(review): this matches ResNet-50's layer3;
            other depths may produce a different width.
        out_feat_key: key under which the feature map is returned.
    """

    def __init__(self, resnet):
        # layer1 -> res2, layer2 -> res3, layer3 -> res4
        stage_names = ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3")
        children = OrderedDict(
            (stage_name, getattr(resnet, stage_name)) for stage_name in stage_names
        )
        super().__init__(children)
        self.out_channels = 1024
        self.out_feat_key = "feat_res4"

    def forward(self, x):
        """Run the trunk and return ``OrderedDict({out_feat_key: feature_map})``."""
        # Delegate to nn.Sequential, which applies the children in order.
        res4_map = super().forward(x)
        return OrderedDict([(self.out_feat_key, res4_map)])
|
|
|
|
|
|
class PVTv2Backbone(pvt_v2_b2):
    """PVTv2-b2 backbone that exposes a single stage's feature map.

    Args:
        pretrained_path: forwarded to ``pvt_v2_b2`` as its ``pretrained``
            argument. NOTE(review): presumably a checkpoint path, with ""
            meaning "no weights" -- confirm against ``backbone.pvt_v2``.
    """

    def __init__(self, pretrained_path=""):
        super().__init__(pretrained=pretrained_path)
        # NOTE(review): 512 is assumed to be the channel width of the stage
        # selected in forward() -- confirm against backbone.pvt_v2.
        self.out_channels = 512
        self.out_feat_key = "feat_pvtv23"

    def forward(self, x):
        """Return ``OrderedDict({out_feat_key: parent_output[3]})``.

        The parent forward is assumed to return an indexable sequence of
        per-stage feature maps (TODO confirm); the entry at index 3 is kept.
        """
        stage_outputs = super().forward(x)
        selected = stage_outputs[3]
        return OrderedDict([(self.out_feat_key, selected)])
|
|
|
|
class Res5Head(nn.Sequential):
    """ResNet ``layer4`` (res5) head that also max-pools its own input.

    Attributes:
        out_channels: hard-coded channel widths of the two returned features,
            in order [res4, res5]. NOTE(review): 1024/2048 match ResNet-50.
    """

    def __init__(self, resnet):
        super().__init__(OrderedDict(layer4=resnet.layer4))  # res5
        self.out_channels = [1024, 2048]

    def forward(self, x):
        """Run layer4 on ``x`` and return globally max-pooled features.

        Returns:
            OrderedDict with, in order, "feat_res4" (``x`` max-pooled to 1x1)
            and "feat_res5" (the layer4 output max-pooled to 1x1).
        """
        layer4_out = super().forward(x)
        pooled = OrderedDict()
        pooled["feat_res4"] = F.adaptive_max_pool2d(x, 1)
        pooled["feat_res5"] = F.adaptive_max_pool2d(layer4_out, 1)
        return pooled
|
|
|
|
|
|
def build_resnet(name="resnet50", pretrained=True):
    """Build a ResNet trunk (stem through res4) wrapped in ``Backbone``.

    Args:
        name: attribute name of a model constructor in
            ``torchvision.models.resnet`` (e.g. "resnet50"). An unknown name
            raises KeyError.
        pretrained: forwarded unchanged to the torchvision constructor.
            NOTE(review): newer torchvision releases deprecate ``pretrained``
            in favour of ``weights``.

    Returns:
        A ``Backbone`` over the constructed ResNet. The stem conv weight and
        the stem BatchNorm weight/bias are frozen.
    """
    constructor = vars(torchvision.models.resnet)[name]
    resnet = constructor(pretrained=pretrained)
    # Freeze the stem: first conv weight plus bn1's affine scale and shift.
    for frozen in (resnet.conv1.weight, resnet.bn1.weight, resnet.bn1.bias):
        frozen.requires_grad_(False)
    return Backbone(resnet)
|
|
|
|
|
|
def build_network(
    name="resnet50",
    pretrained=True,
    *,
    pvt_pretrained_path="./backbone/pvt_v2_b2.pth",
):
    """Build the backbone network selected by ``name``.

    Args:
        name: "resnet50" builds a torchvision ResNet-50 trunk via
            ``build_resnet``. ANY other value builds a PVTv2-b2 backbone.
        pretrained: forwarded to ``build_resnet``. NOTE(review): it is not
            consulted on the PVTv2 branch, which always uses
            ``pvt_pretrained_path``.
        pvt_pretrained_path: checkpoint path handed to ``PVTv2Backbone``.
            Keyword-only. The default is the previously hard-coded value; it
            is a relative path, so it resolves against the current working
            directory.

    Returns:
        The backbone module: a ``Backbone`` or a ``PVTv2Backbone``.
    """
    if name == "resnet50":
        return build_resnet(name, pretrained)
    # Every other name falls through to the PVTv2-b2 network.
    return PVTv2Backbone(pretrained_path=pvt_pretrained_path)