diff --git a/src/models.py b/src/models.py index caf7783..5c7771a 100644 --- a/src/models.py +++ b/src/models.py @@ -29,6 +29,35 @@ class singlePhotonNet_250909(nn.Module): x = self.fc(x) return x +class singlePhotonNet_251020(nn.Module): + ''' + Smaller input size (3x3) + ''' + def weight_init(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def __init__(self): + super(singlePhotonNet_251020, self).__init__() + self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1) + self.fc = nn.Linear(20*3*3, 2) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + class doublePhotonNet_250909(nn.Module): def __init__(self): super(doublePhotonNet_250909, self).__init__()