From dd8751466ab456b5794ea22a3e2dd23fef1d96f9 Mon Sep 17 00:00:00 2001 From: siyli Date: Tue, 4 Jan 2022 13:13:55 +0100 Subject: [PATCH] fix unused parameters --- dla.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dla.py b/dla.py index d9f3007..d761968 100644 --- a/dla.py +++ b/dla.py @@ -196,7 +196,7 @@ def __init__(self, levels, block, in_channels, out_channels, stride=1, self.levels = levels if stride > 1: self.downsample = nn.MaxPool2d(stride, stride=stride) - if in_channels != out_channels: + if in_channels != out_channels and levels != 1: self.project = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), @@ -245,10 +245,10 @@ def __init__(self, levels, channels, num_classes=1000, level_root=True, root_residual=residual_root) self.level5 = Tree(levels[5], block, channels[4], channels[5], 2, level_root=True, root_residual=residual_root) - - self.avgpool = nn.AvgPool2d(pool_size) - self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1, - stride=1, padding=0, bias=True) + if not self.return_levels: + self.avgpool = nn.AvgPool2d(pool_size) + self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1, + stride=1, padding=0, bias=True) for m in self.modules(): if isinstance(m, nn.Conv2d):