# Source listing metadata: 29 lines, 1.1 KiB, Python
from collections import OrderedDict

import torch
import torchvision

# Demo: pool a set of boxes from several feature maps with
# torchvision.ops.MultiScaleRoIAlign and print the resulting shape.
#
# Constructor arguments of MultiScaleRoIAlign:
#   featmap_names (List[str]): the names of the feature maps that will be used
#       for the pooling.
#   output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
#   sampling_ratio (int): sampling ratio for ROIAlign
#
#   canonical_scale (int, optional): canonical_scale for LevelMapper
#   canonical_level (int, optional): canonical_level for LevelMapper

# Positional arguments, in order: names of the feature maps to pool from,
# output size, sampling ratio.
roi = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 5, 2)

# Mapping from feature-map name to a simulated feature tensor.
i = OrderedDict()
i['feat1'] = torch.rand(1, 5, 64, 64)
# this feature won't be used in the pooling
i['feat2'] = torch.rand(1, 5, 32, 32)
i['feat3'] = torch.rand(1, 5, 16, 16)

# Create 6 random boxes: draw four values in [0, 256) per box, then add the
# first two columns onto the last two so that columns 2-3 are never smaller
# than columns 0-1 (i.e. a valid (x1, y1, x2, y2) box).
boxes = torch.rand(6, 4) * 256
boxes[:, 2:] += boxes[:, :2]

# original image size, before computing the feature maps
image_sizes = [(512, 512)]

output = roi(i, [boxes], image_sizes)
print(output.shape)
# print(output)

# Expected output: torch.Size([6, 5, 5, 5])
# 6 boxes, 5 channels (the channel count of the feature maps), and a 5x5
# spatial size that comes directly from output_size=5 passed to the
# constructor above. With output_size=3 the result would instead be
# torch.Size([6, 5, 3, 3]).