Add normalize, noiseThreshold options; add rms x/y outputs

This commit is contained in:
2026-03-18 08:32:35 +01:00
parent 9d7970856e
commit 572d798b72
2 changed files with 46 additions and 9 deletions
+32 -8
View File
@@ -16,7 +16,7 @@ torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
modelVersion = '251022' # '250909' or '251020'
Energy = '15.3keV'
Energy = '15keV'
TrainLosses, ValLosses = [], []
LearningRates = []
TestLoss = -1
@@ -24,7 +24,9 @@ model = models.get_model_class(modelVersion)().cuda()
# summary(model, input_size=(128, 1, 3, 3))
LearningRate = 1e-3
Noise = 0.13 # in keV
numberOfAugOps = 8 # 1 (no augmentation) or (1,8] (with augmentation)
NoiseThreshold = 0 * Noise # in keV, set values below this threshold to zero
numberOfAugOps = 1 # 1 (no augmentation) or (1,8] (with augmentation)
flag_normalize = False
TrainLosses, TestLosses = [], []
def weighted_loss(pred, target, alpha=7.0):
@@ -46,6 +48,7 @@ LossFunction = weighted_loss
def train(model, trainLoader, optimizer):
model.train()
batchLoss = 0
rms_x, rms_y = 0, 0
for batch_idx, (sample, label) in enumerate(trainLoader):
sample, label = sample.cuda(), label.cuda()
x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3]
@@ -55,10 +58,14 @@ def train(model, trainLoader, optimizer):
loss.backward()
optimizer.step()
batchLoss += loss.item() * sample.shape[0]
rms_x += torch.sum((output[:,0] - x)**2).item()
rms_y += torch.sum((output[:,1] - y)**2).item()
avgLoss = batchLoss / len(trainLoader.dataset)
rms_x = np.sqrt(rms_x / len(trainLoader.dataset))
rms_y = np.sqrt(rms_y / len(trainLoader.dataset))
datasetName = trainLoader.dataset.datasetName
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
TrainLosses.append(avgLoss)
def test(model, testLoader):
@@ -89,6 +96,8 @@ trainDataset = singlePhotonDataset(
datasetName='Train',
noiseKeV = Noise,
numberOfAugOps=numberOfAugOps,
normalize=flag_normalize,
noiseThreshold=NoiseThreshold
)
valDataset = singlePhotonDataset(
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)],
@@ -96,12 +105,17 @@ valDataset = singlePhotonDataset(
datasetName='Val',
noiseKeV = Noise,
numberOfAugOps=numberOfAugOps,
normalize=flag_normalize,
noiseThreshold=NoiseThreshold
)
testDataset = singlePhotonDataset(
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)],
sampleRatio=1.0,
datasetName='Test',
noiseKeV = Noise,
numberOfAugOps=numberOfAugOps,
normalize=flag_normalize,
noiseThreshold=NoiseThreshold
)
trainLoader = torch.utils.data.DataLoader(
trainDataset,
@@ -133,11 +147,19 @@ if __name__ == "__main__":
test(model, valLoader)
scheduler.step(ValLosses[-1])
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]:
torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}_aug{numberOfAugOps}.pth')
if epoch in [20, 30, 50, 100, 200, 300, 500, 750, 1000]:
modelName = f'singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}_aug{numberOfAugOps}'
if flag_normalize == True:
modelName += '_normalized'
torch.save(model.state_dict(), f'Models/{modelName}.pth')
print(f"Saved model checkpoint: {modelName}.pth")
test(model, testLoader)
torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_aug{numberOfAugOps}.pth')
modelName = f'singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}_aug{numberOfAugOps}'
if flag_normalize == True:
modelName += '_normalized'
torch.save(model.state_dict(), f'Models/{modelName}.pth')
print(f"Saved final model checkpoint: {modelName}.pth")
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
import matplotlib.pyplot as plt
@@ -151,6 +173,8 @@ def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
plt.ylabel('MSE Loss')
plt.legend()
plt.grid()
plt.savefig(f'Results/loss_curve_singlePhoton_{modelVersion}.png', dpi=300)
plotName = f'loss_curve_singlePhoton_{modelVersion}'
if flag_normalize:
plotName += '_normalized'
plt.savefig(f'Results/{plotName}.png')
plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)