This commit is contained in:
2025-10-22 08:03:51 +02:00
parent 1f7998c77a
commit 324b81fc1b
3 changed files with 130 additions and 86 deletions

View File

@@ -1,86 +0,0 @@
import sys
sys.path.append('./src')
import torch
import numpy as np
from models import *
from datasets import *
import torch.optim as optim
from tqdm import tqdm
TrainLosses, TestLosses = [], []
def train(model, trainLoader, optimizer):
model.train()
batchLoss = 0
for batch_idx, (sample, label) in enumerate(trainLoader):
sample, label = sample.cuda(), label.cuda()
x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3]
optimizer.zero_grad()
output = model(sample)
loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), axis=1))
loss.backward()
optimizer.step()
batchLoss += loss.item() * sample.shape[0]
avgLoss = batchLoss / len(trainLoader.dataset)
print(f"[Train] Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
TrainLosses.append(avgLoss)
def test(model, testLoader):
model.eval()
batchLoss = 0
with torch.no_grad():
for batch_idx, (sample, label) in enumerate(testLoader):
sample, label = sample.cuda(), label.cuda()
x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3]
output = model(sample)
loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), axis=1))
batchLoss += loss.item() * sample.shape[0]
avgLoss = batchLoss / len(testLoader.dataset)
print(f"[Test] Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
TestLosses.append(avgLoss)
return avgLoss
model = singlePhotonNet_250909().cuda()
from glob import glob
sampleFileList = glob('/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_*_samples.npy')
labelFileList = glob('/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_*_labels.npy')
trainDataset = singlePhotonDataset(sampleFileList, labelFileList, sampleRatio=0.1)
nTrainSamples = int(0.8 * len(trainDataset))
nTestSamples = len(trainDataset) - nTrainSamples
trainDataset, testDataset = torch.utils.data.random_split(trainDataset, [nTrainSamples, nTestSamples])
trainLoader = torch.utils.data.DataLoader(
trainDataset,
batch_size=1024,
pin_memory = True,
shuffle=True,
num_workers=16
)
testLoader = torch.utils.data.DataLoader(
testDataset,
batch_size=4096,
shuffle=False,
num_workers=16
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 5, verbose = True)
if __name__ == "__main__":
for epoch in tqdm(range(1, 201)):
train(model, trainLoader, optimizer)
test(model, testLoader)
scheduler.step(TestLosses[-1])
torch.save(model.state_dict(), 'singlePhotonNet_250909.pth')
import matplotlib.pyplot as plt
plt.figure(figsize=(8,6))
plt.plot(TrainLosses, label='Train Loss')
plt.plot(TestLosses, label='Test Loss')
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.legend()
plt.grid()
plt.savefig('loss_curve.png', dpi=300)

130
Train_SinglePhoton.py Normal file
View File

@@ -0,0 +1,130 @@
import sys
sys.path.append('./src')
import torch
import numpy as np
from models import *
from datasets import *
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary
### random seed for reproducibility
torch.manual_seed(0)
np.random.seed(0)
modelVersion = '251020' # '250909' or '251020'
TrainLosses, ValLosses = [], []
TestLoss = -1
Noise = 0.13 # in keV
TrainLosses, TestLosses = [], []
def train(model, trainLoader, optimizer):
model.train()
batchLoss = 0
for batch_idx, (sample, label) in enumerate(trainLoader):
sample, label = sample.cuda(), label.cuda()
x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3]
optimizer.zero_grad()
output = model(sample)
loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), axis=1))
loss.backward()
optimizer.step()
batchLoss += loss.item() * sample.shape[0]
avgLoss = batchLoss / len(trainLoader.dataset)
datasetName = trainLoader.dataset.datasetName
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
TrainLosses.append(avgLoss)
def test(model, testLoader):
model.eval()
batchLoss = 0
with torch.no_grad():
for batch_idx, (sample, label) in enumerate(testLoader):
sample, label = sample.cuda(), label.cuda()
x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3]
output = model(sample)
loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), axis=1))
batchLoss += loss.item() * sample.shape[0]
avgLoss = batchLoss / len(testLoader.dataset)
datasetName = testLoader.dataset.datasetName
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
if datasetName == 'Val':
ValLosses.append(avgLoss)
else:
global TestLoss
TestLoss = avgLoss
return avgLoss
model = singlePhotonNet_251020().cuda()
# summary(model, input_size=(128, 1, 3, 3))
sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'
trainDataset = singlePhotonDataset(
[f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(13)],
sampleRatio=1.0,
datasetName='Train',
noiseKeV = Noise,
)
valDataset = singlePhotonDataset(
[f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(13,14)],
sampleRatio=1.0,
datasetName='Val',
noiseKeV = Noise,
)
testDataset = singlePhotonDataset(
[f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(15,16)],
sampleRatio=1.0,
datasetName='Test',
noiseKeV = Noise,
)
trainLoader = torch.utils.data.DataLoader(
trainDataset,
batch_size=1024,
shuffle=True,
num_workers=16,
pin_memory=True,
)
valLoader = torch.utils.data.DataLoader(
valDataset,
batch_size=1024,
shuffle=False,
num_workers=16,
pin_memory=True,
)
testLoader = torch.utils.data.DataLoader(
testDataset,
batch_size=1024,
shuffle=4096,
num_workers=16,
pin_memory=True,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 5)
if __name__ == "__main__":
for epoch in tqdm(range(1, 101)):
train(model, trainLoader, optimizer)
test(model, valLoader)
scheduler.step(TrainLosses[-1])
test(model, testLoader)
torch.save(model.state_dict(), f'Models/singlePhotonNet_Noise{Noise}keV_{modelVersion}.pth')
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
import matplotlib.pyplot as plt
plt.figure(figsize=(8,6))
plt.plot(TrainLosses, label='Train Loss', color='blue')
plt.plot(ValLosses, label='Validation Loss', color='orange')
if TestLoss > 0:
plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.legend()
plt.grid()
plt.savefig(f'Results/loss_curve_singlePhoton_{modelVersion}.png', dpi=300)
plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)