Dynamic loading model
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
import torch
|
||||
import sys
|
||||
sys.path.append('./src')
|
||||
from models import *
|
||||
import models
|
||||
from datasets import *
|
||||
from tqdm import tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
configs = {}
|
||||
configs['SiemenStar'] = {
|
||||
'dataFiles': [f'/home/xie_x1/MLXID/DeepLearning/SiemenStarClusters/clusters_chunk{i}.h5' for i in range(100)],
|
||||
'modelVersion': 'singlePhotonNet_251020',
|
||||
'dataFiles': [f'/mnt/sls_det_storage/moench_data/MLXID/Samples/Measurement/2504_SOLEIL_SiemenStarClusters_MOENCH040_150V/clusters_chunk{i}.h5' for i in range(32)],
|
||||
'modelVersion': '251020',
|
||||
'roi': [140, 230, 120, 210], # x_min, x_max, y_min, y_max,
|
||||
'noise': 0.13 # in keV
|
||||
}
|
||||
@@ -24,9 +24,8 @@ if __name__ == "__main__":
|
||||
task = 'SiemenStar'
|
||||
config = configs[task]
|
||||
|
||||
if config['modelVersion'] == 'singlePhotonNet_251020':
|
||||
model = singlePhotonNet_251020().cuda()
|
||||
model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/singlePhotonNet_Noise{config["noise"]}keV_251020.pth', weights_only=True))
|
||||
model = models.get_model_class(config['modelVersion'])().cuda()
|
||||
model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/singlePhotonNet_Noise{config["noise"]}keV_{config["modelVersion"]}.pth', weights_only=True))
|
||||
|
||||
dataset = singlePhotonDataset(config['dataFiles'], sampleRatio=1.0, datasetName='Inference')
|
||||
dataLoader = torch.utils.data.DataLoader(
|
||||
|
||||
@@ -2,7 +2,7 @@ import sys
|
||||
sys.path.append('./src')
|
||||
import torch
|
||||
import numpy as np
|
||||
from models import *
|
||||
import models
|
||||
from datasets import *
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
@@ -12,7 +12,7 @@ from torchinfo import summary
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
modelVersion = '251020' # '250909' or '251020'
|
||||
modelVersion = '251022' # '250909' or '251020'
|
||||
TrainLosses, ValLosses = [], []
|
||||
TestLoss = -1
|
||||
Noise = 0.13 # in keV
|
||||
@@ -58,8 +58,9 @@ def test(model, testLoader):
|
||||
TestLoss = avgLoss
|
||||
return avgLoss
|
||||
|
||||
model = singlePhotonNet_251020().cuda()
|
||||
model = models.get_model_class(modelVersion)().cuda()
|
||||
# summary(model, input_size=(128, 1, 3, 3))
|
||||
|
||||
sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'
|
||||
trainDataset = singlePhotonDataset(
|
||||
[f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(13)],
|
||||
|
||||
@@ -3,6 +3,13 @@ import torch.nn.functional as F
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def get_model_class(version):
|
||||
class_name = f'singlePhotonNet_{version}'
|
||||
cls = globals().get(class_name)
|
||||
if cls is None:
|
||||
raise ValueError(f"Model class '{class_name}' not found.")
|
||||
return cls
|
||||
|
||||
class singlePhotonNet_250909(nn.Module):
|
||||
def weight_init(self):
|
||||
for m in self.modules():
|
||||
@@ -58,6 +65,36 @@ class singlePhotonNet_251020(nn.Module):
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
class singlePhotonNet_251022(nn.Module):
|
||||
'''
|
||||
Smaller input size (3x3)
|
||||
'''
|
||||
def weight_init(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def __init__(self):
|
||||
super(singlePhotonNet_251022, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
|
||||
self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Linear(20, 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = self.gap(x).view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
class doublePhotonNet_250909(nn.Module):
|
||||
def __init__(self):
|
||||
super(doublePhotonNet_250909, self).__init__()
|
||||
|
||||
Reference in New Issue
Block a user