131 lines
4.1 KiB
Python
131 lines
4.1 KiB
Python
import sys
|
|
sys.path.append('./src')
|
|
import torch
|
|
import numpy as np
|
|
import models
|
|
from datasets import *
|
|
import torch.optim as optim
|
|
from tqdm import tqdm
|
|
from torchinfo import summary
|
|
|
|
### random seed for reproducibility
torch.manual_seed(0)
np.random.seed(0)

# Model version key passed to models.get_model_class()
# (e.g. '250909', '251020' or '251022').
modelVersion = '251022'

# Loss histories written by train() / test():
#   TrainLosses - one average training loss per epoch (appended by train())
#   ValLosses   - one average validation loss per epoch (appended by test())
#   TestLoss    - final test-set loss; -1 means "not evaluated yet"
TrainLosses, ValLosses = [], []
TestLoss = -1
# NOTE(review): TestLosses is never written in this script; kept only so
# external code that imports it keeps working.
TestLosses = []

Noise = 0.13  # in keV; forwarded to the datasets as noiseKeV
|
|
|
|
def train(model, trainLoader, optimizer):
    """Run one training epoch and record its sample-weighted average MSE loss.

    The model regresses label columns 0 and 1 (x, y) from each input sample;
    any further label columns are ignored here.

    Args:
        model: network to optimise. Batches are moved to the device that the
            model's parameters live on.
        trainLoader: DataLoader yielding ``(sample, label)`` batches. Its
            dataset must expose a ``datasetName`` attribute (used for logging).
        optimizer: optimiser that updates ``model``'s parameters.

    Returns:
        The epoch's average per-sample MSE loss. The value is also appended to
        the module-level ``TrainLosses`` list.
    """
    model.train()
    # Follow the model rather than hard-coding .cuda(): identical for a GPU
    # model, and still usable when the model lives on the CPU.
    device = next(model.parameters()).device

    totalLoss = 0.0
    for sample, label in trainLoader:
        sample = sample.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True)
        # Regression target: label columns 0 and 1 (x, y).
        target = label[:, :2]

        optimizer.zero_grad()
        output = model(sample)
        loss = torch.nn.functional.mse_loss(output, target)
        loss.backward()
        optimizer.step()

        # mse_loss averages within the batch; re-weight by batch size so the
        # epoch average stays exact when the last batch is smaller.
        totalLoss += loss.item() * sample.shape[0]

    avgLoss = totalLoss / len(trainLoader.dataset)

    datasetName = trainLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
    TrainLosses.append(avgLoss)
    return avgLoss
|
|
|
|
def test(model, testLoader):
    """Evaluate ``model`` on ``testLoader`` without gradients; return the average MSE.

    The loss compares the model output with label columns 0 and 1 (x, y),
    weighted per sample.

    Args:
        model: network to evaluate. Batches are moved to the device that the
            model's parameters live on.
        testLoader: DataLoader yielding ``(sample, label)`` batches. Its
            dataset must expose a ``datasetName`` attribute.

    Returns:
        The average per-sample MSE loss over the whole dataset.

    Side effects:
        If ``datasetName == 'Val'``, the loss is appended to the module-level
        ``ValLosses``. For any other name, it overwrites the module-level
        ``TestLoss``.
    """
    # Declared up front: this function may rebind the module-level TestLoss.
    global TestLoss

    model.eval()
    device = next(model.parameters()).device

    totalLoss = 0.0
    with torch.no_grad():
        for sample, label in testLoader:
            sample = sample.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)
            output = model(sample)
            loss = torch.nn.functional.mse_loss(output, label[:, :2])
            # Re-weight the batch mean by batch size for an exact overall mean.
            totalLoss += loss.item() * sample.shape[0]

    avgLoss = totalLoss / len(testLoader.dataset)

    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")

    if datasetName == 'Val':
        ValLosses.append(avgLoss)
    else:
        TestLoss = avgLoss

    return avgLoss
|
|
|
|
# Instantiate the architecture registered under modelVersion, then move it
# to the GPU (Module.cuda() returns the module itself).
model = models.get_model_class(modelVersion)()
model = model.cuda()
# Optional layer-by-layer overview via torchinfo:
# summary(model, input_size=(128, 1, 3, 3))
|
|
|
|
sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'


def _moench040Files(indices):
    """Return the ``15keV_Moench040_150V_<i>.npz`` paths in sampleFolder for each index."""
    return [f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in indices]


# Split by file index: 0-12 train, 13 validation, 15 test.
# NOTE(review): index 14 is not used by any split - confirm that is intended.
trainDataset = singlePhotonDataset(
    _moench040Files(range(13)),
    sampleRatio=1.0,
    datasetName='Train',
    noiseKeV=Noise,
)
valDataset = singlePhotonDataset(
    _moench040Files(range(13, 14)),
    sampleRatio=1.0,
    datasetName='Val',
    noiseKeV=Noise,
)
testDataset = singlePhotonDataset(
    _moench040Files(range(15, 16)),
    sampleRatio=1.0,
    datasetName='Test',
    noiseKeV=Noise,
)
|
|
|
|
def _makeLoader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the batching settings shared by all splits."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=1024,
        shuffle=shuffle,  # boolean flag: reshuffle every epoch or not
        num_workers=16,
        pin_memory=True,
    )


trainLoader = _makeLoader(trainDataset, shuffle=True)
# Evaluation loaders keep a fixed order. The test loader was previously given
# shuffle=4096, a truthy non-boolean that silently enabled shuffling.
valLoader = _makeLoader(valDataset, shuffle=False)
testLoader = _makeLoader(testDataset, shuffle=False)
|
|
|
|
# Adam at lr=1e-3. The scheduler multiplies the LR by 0.7 once the monitored
# loss has failed to improve for more than 5 consecutive scheduler steps.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.7,
    patience=5,
)
|
|
if __name__ == "__main__":
    import os

    for epoch in tqdm(range(1, 101)):
        train(model, trainLoader, optimizer)
        valLoss = test(model, valLoader)
        # Plateau detection uses held-out validation loss, so it can react to
        # overfitting. Training loss keeps falling even when generalisation stalls.
        scheduler.step(valLoss)

    # Final evaluation on the held-out test split (stored in TestLoss by test()).
    test(model, testLoader)

    # torch.save does not create missing directories.
    os.makedirs('Models', exist_ok=True)
    torch.save(model.state_dict(), f'Models/singlePhotonNet_Noise{Noise}keV_{modelVersion}.pth')
|
|
|
|
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
    """Plot per-epoch train/validation MSE on a log scale and save it as a PNG.

    Args:
        TrainLosses: per-epoch training losses, epoch 1 first.
        ValLosses: per-epoch validation losses, epoch 1 first.
        TestLoss: final test loss. It is drawn as a dashed horizontal line only
            when > 0; the script uses -1 to mean "not evaluated".
        modelVersion: inserted into the output filename.

    The figure is written to ``Results/loss_curve_singlePhoton_<modelVersion>.png``
    at 300 dpi. The directory is created if missing, and the figure is closed
    afterwards.
    """
    import os
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(8, 6))
    # The training loop numbers epochs from 1, so align the x-axis with that.
    plt.plot(range(1, len(TrainLosses) + 1), TrainLosses, label='Train Loss', color='blue')
    plt.plot(range(1, len(ValLosses) + 1), ValLosses, label='Validation Loss', color='orange')
    if TestLoss > 0:
        plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid()

    # savefig does not create missing directories.
    os.makedirs('Results', exist_ok=True)
    fig.savefig(f'Results/loss_curve_singlePhoton_{modelVersion}.png', dpi=300)
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig)
|
|
|
|
if __name__ == "__main__":
    # Plot only when run as a script. On import, the loss lists are still empty,
    # and an unguarded call would overwrite the saved figure with a blank plot.
    plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)