54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
|
# This file is part of COAT, and is distributed under the
|
||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||
|
|
||
|
from collections import OrderedDict
|
||
|
import torch.nn.functional as F
|
||
|
import torchvision
|
||
|
from torch import nn
|
||
|
|
||
|
class Backbone(nn.Sequential):
    """ResNet trunk from the stem through ``layer3`` (the "res4" stage).

    The stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and ``layer1`` to
    ``layer3`` are registered by reference, not copied. They therefore share
    parameters with the ``resnet`` object passed in.

    Attributes:
        out_channels: channel count of the ``layer3`` output. It is read from
            the last ``nn.Conv2d`` inside ``resnet.layer3``, so it is right for
            any ResNet depth. If ``layer3`` contains no ``nn.Conv2d``, it falls
            back to 1024, the previous hard-coded value.
    """

    def __init__(self, resnet):
        super(Backbone, self).__init__(
            OrderedDict(
                [
                    ["conv1", resnet.conv1],
                    ["bn1", resnet.bn1],
                    ["relu", resnet.relu],
                    ["maxpool", resnet.maxpool],
                    ["layer1", resnet.layer1],  # res2
                    ["layer2", resnet.layer2],  # res3
                    ["layer3", resnet.layer3],  # res4
                ]
            )
        )
        # A hard-coded 1024 is only correct for Bottleneck ResNets.
        # BasicBlock ResNets (torchvision resnet18/34) emit 256 channels from
        # layer3. The last conv of the stage determines its output width.
        convs = [m for m in resnet.layer3.modules() if isinstance(m, nn.Conv2d)]
        self.out_channels = convs[-1].out_channels if convs else 1024

    def forward(self, x):
        """Run the trunk and return ``OrderedDict(feat_res4=<layer3 output>)``."""
        # Delegate to nn.Sequential, which applies the registered modules in order.
        feat = super(Backbone, self).forward(x)
        return OrderedDict([["feat_res4", feat]])
|
||
|
|
||
|
|
||
|
class Res5Head(nn.Sequential):
    """ResNet ``layer4`` ("res5") stage with global max pooling of in/out maps.

    ``forward`` returns an ``OrderedDict`` with two entries:

    - ``"feat_res4"``: the input, pooled to 1x1 with adaptive max pooling.
    - ``"feat_res5"``: the ``layer4`` output, pooled the same way.

    ``layer4`` is registered by reference, so it shares parameters with the
    ``resnet`` passed in.

    Attributes:
        out_channels: ``[res4_channels, res5_channels]``. These are the input
            channels of the first ``nn.Conv2d`` and the output channels of the
            last ``nn.Conv2d`` in ``resnet.layer4``. If ``layer4`` contains no
            ``nn.Conv2d``, it falls back to ``[1024, 2048]``, the previous
            hard-coded value.
    """

    def __init__(self, resnet):
        super(Res5Head, self).__init__(OrderedDict([["layer4", resnet.layer4]]))  # res5
        # [1024, 2048] is only correct for Bottleneck ResNets. BasicBlock
        # ResNets (torchvision resnet18/34) use [256, 512]. The first conv
        # sees the stage input and the last conv produces the stage output.
        convs = [m for m in resnet.layer4.modules() if isinstance(m, nn.Conv2d)]
        if convs:
            self.out_channels = [convs[0].in_channels, convs[-1].out_channels]
        else:
            self.out_channels = [1024, 2048]

    def forward(self, x):
        """Apply ``layer4`` and return globally max-pooled res4/res5 features."""
        feat = super(Res5Head, self).forward(x)
        # Collapse the spatial dims of both maps to 1x1 with max pooling.
        x = F.adaptive_max_pool2d(x, 1)
        feat = F.adaptive_max_pool2d(feat, 1)
        return OrderedDict([["feat_res4", x], ["feat_res5", feat]])
|
||
|
|
||
|
|
||
|
def build_resnet(name="resnet50", pretrained=True):
    """Construct a torchvision ResNet and split it into trunk and head.

    Args:
        name: name of a constructor in ``torchvision.models.resnet``, such as
            ``"resnet50"``. It is looked up directly in that module's
            namespace, so an unknown name raises ``KeyError``.
        pretrained: forwarded unchanged to the constructor as ``pretrained=``.

    Returns:
        ``(Backbone(resnet), Res5Head(resnet))``. Both wrap submodules of the
        same ``resnet`` instance.
    """
    constructor = vars(torchvision.models.resnet)[name]
    resnet = constructor(pretrained=pretrained)

    # Stop gradients for the stem conv weight and bn1's affine parameters.
    # NOTE(review): requires_grad_ only affects gradients. If bn1 is a
    # BatchNorm layer, its running statistics still update in train mode.
    # Confirm that this is intended.
    stem_params = (resnet.conv1.weight, resnet.bn1.weight, resnet.bn1.bias)
    for param in stem_params:
        param.requires_grad_(False)

    return Backbone(resnet), Res5Head(resnet)
|