Dynamic loading model

This commit is contained in:
2025-10-22 09:24:49 +02:00
parent 01fd862833
commit 66e7792ce9
3 changed files with 46 additions and 9 deletions
+37
View File
@@ -3,6 +3,13 @@ import torch.nn.functional as F
import torch
import numpy as np
def get_model_class(version):
class_name = f'singlePhotonNet_{version}'
cls = globals().get(class_name)
if cls is None:
raise ValueError(f"Model class '{class_name}' not found.")
return cls
class singlePhotonNet_250909(nn.Module):
def weight_init(self):
for m in self.modules():
@@ -58,6 +65,36 @@ class singlePhotonNet_251020(nn.Module):
x = self.fc(x)
return x
class singlePhotonNet_251022(nn.Module):
'''
Smaller input size (3x3)
'''
def weight_init(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def __init__(self):
super(singlePhotonNet_251022, self).__init__()
self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(20, 2)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = self.gap(x).view(x.size(0), -1)
x = self.fc(x)
return x
class doublePhotonNet_250909(nn.Module):
def __init__(self):
super(doublePhotonNet_250909, self).__init__()