From 6390b40be3deb2789d552575f142da46950b3dc8 Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Thu, 28 May 2026 16:33:35 +0200 Subject: [PATCH] Add new single photon model --- src/models.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/models.py b/src/models.py index 62a4b48..ac4b81d 100644 --- a/src/models.py +++ b/src/models.py @@ -102,6 +102,44 @@ class singlePhotonNet_251022(nn.Module): x = self.fc(x) return x +import torch +import torch.nn as nn +import torch.nn.functional as F + +class singlePhotonNet_260511(nn.Module): + def __init__(self): + super(singlePhotonNet_260511, self).__init__() + + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(32, 64, kernel_size=3) + + self.fc = nn.Sequential( + nn.Linear(64, 2), + ) + + self.weight_init() + + def weight_init(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + c1 = F.relu(self.conv1(x)) # [B, 16, 3, 3] + c2 = F.relu(self.conv2(c1)) # [B, 32, 3, 3] + c3 = F.relu(self.conv3(c2)) # [B, 64, 1, 1] + + flat_feat = c3.view(c3.size(0), -1) # [B, 64] + + coords = self.fc(flat_feat) # [B, 2] + return coords + class doublePhotonNet_250909(nn.Module): def __init__(self): super(doublePhotonNet_250909, self).__init__()