COAT/coat-pvtv2-b2/models/resnet.py

54 lines
1.8 KiB
Python
Raw Normal View History

2023-10-10 21:52:30 +08:00
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
from collections import OrderedDict
import torch.nn.functional as F
import torchvision
from torch import nn
class Backbone(nn.Sequential):
    """ResNet trunk from the stem through ``layer3`` (res2-res4).

    The selected submodules are registered in order, so ``nn.Sequential``'s
    forward runs them one after another. The result is the res4 feature map,
    returned in an ``OrderedDict`` under the key ``"feat_res4"``.

    Args:
        resnet: a ResNet-like object exposing ``conv1``, ``bn1``, ``relu``,
            ``maxpool`` and ``layer1``-``layer3``. The submodules are shared
            with ``resnet``, not copied.

    Attributes:
        out_channels: channel count of the res4 feature map. It is read from
            the last ``nn.Conv2d`` inside ``resnet.layer3`` (e.g. 1024 for
            torchvision's resnet50, 256 for resnet18). It falls back to 1024
            when ``layer3`` contains no ``nn.Conv2d``.
    """

    def __init__(self, resnet):
        super().__init__(
            OrderedDict(
                [
                    ["conv1", resnet.conv1],
                    ["bn1", resnet.bn1],
                    ["relu", resnet.relu],
                    ["maxpool", resnet.maxpool],
                    ["layer1", resnet.layer1],  # res2
                    ["layer2", resnet.layer2],  # res3
                    ["layer3", resnet.layer3],  # res4
                ]
            )
        )
        # Derive the width instead of assuming a Bottleneck ResNet, so that
        # BasicBlock variants (resnet18/34) report their real res4 width.
        self.out_channels = self._last_conv_out_channels(resnet.layer3, default=1024)

    @staticmethod
    def _last_conv_out_channels(stage, default):
        """Return ``out_channels`` of the last ``nn.Conv2d`` in *stage*, else *default*."""
        convs = [m for m in stage.modules() if isinstance(m, nn.Conv2d)]
        return convs[-1].out_channels if convs else default

    def forward(self, x):
        # nn.Sequential.forward runs the registered stages in order.
        feat = super().forward(x)
        return OrderedDict([["feat_res4", feat]])
class Res5Head(nn.Sequential):
    """ResNet ``layer4`` (res5) used as a head on res4 features.

    ``forward`` runs ``layer4`` on its input. It then returns both the input
    (as ``"feat_res4"``) and the ``layer4`` output (as ``"feat_res5"``), each
    reduced to 1x1 spatially by global max pooling.

    Args:
        resnet: a ResNet-like object exposing ``layer4``. The module is
            shared with ``resnet``, not copied.

    Attributes:
        out_channels: ``[res4_channels, res5_channels]``. These are taken from
            the first ``nn.Conv2d`` in ``layer4`` (its input width) and the
            last one (its output width), e.g. ``[1024, 2048]`` for
            torchvision's resnet50. They fall back to ``[1024, 2048]`` when
            ``layer4`` contains no ``nn.Conv2d``.
    """

    def __init__(self, resnet):
        super().__init__(OrderedDict([["layer4", resnet.layer4]]))  # res5
        convs = [m for m in resnet.layer4.modules() if isinstance(m, nn.Conv2d)]
        if convs:
            # The first conv consumes res4 features; the last produces res5.
            self.out_channels = [convs[0].in_channels, convs[-1].out_channels]
        else:
            self.out_channels = [1024, 2048]

    def forward(self, x):
        feat = super().forward(x)
        # Global max-pool both the incoming res4 map and the res5 output.
        x = F.adaptive_max_pool2d(x, 1)
        feat = F.adaptive_max_pool2d(feat, 1)
        return OrderedDict([["feat_res4", x], ["feat_res5", feat]])
def build_resnet(name="resnet50", pretrained=True):
    """Build a torchvision ResNet and split it into a res4 trunk and a res5 head.

    Args:
        name: key looked up in the ``torchvision.models.resnet`` module
            namespace (e.g. ``"resnet50"``). A missing key raises ``KeyError``.
        pretrained: forwarded unchanged as the builder's ``pretrained`` keyword.
            NOTE(review): newer torchvision releases deprecate ``pretrained``
            in favour of ``weights``; confirm against the installed version.

    Returns:
        A ``(Backbone, Res5Head)`` pair built from the same model instance,
        so the two share submodules.

    Before splitting, ``requires_grad`` is switched off for the stem's
    ``conv1.weight`` and the ``bn1`` weight and bias.
    NOTE(review): if ``bn1`` is a standard ``nn.BatchNorm2d``, its running
    statistics still update in train mode; only the affine params are frozen.
    """
    builder = vars(torchvision.models.resnet)[name]
    model = builder(pretrained=pretrained)

    stem_params = (model.conv1.weight, model.bn1.weight, model.bn1.bias)
    for param in stem_params:
        param.requires_grad_(False)

    trunk = Backbone(model)
    head = Res5Head(model)
    return trunk, head