From 1f7998c77a2e96cfe0dd0a14e99f2b9c0dd5b726 Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Wed, 22 Oct 2025 08:03:15 +0200 Subject: [PATCH] Set random seed for reproducibility --- Train_doublePhoton.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/Train_doublePhoton.py b/Train_doublePhoton.py index 911ba29..50c9ef6 100644 --- a/Train_doublePhoton.py +++ b/Train_doublePhoton.py @@ -6,6 +6,11 @@ from models import * from datasets import * import torch.optim as optim from tqdm import tqdm +from torchinfo import summary + +### random seed for reproducibility +torch.manual_seed(0) +np.random.seed(0) modelVersion = '251001_2' # '250910' or '251001' TrainLosses, ValLosses = [], [] @@ -45,6 +50,8 @@ elif modelVersion == '251001_2': loss_fn = min_matching_loss model = doublePhotonNet_251001_2().cuda() +# summary(model, input_size=(128, 1, 6, 6)) ### print model summary + def train(model, trainLoader, optimizer): model.train() batchLoss = 0 @@ -142,7 +149,7 @@ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, p # test(model, valLoader) # scheduler.step(TrainLosses[-1]) -model.load_state_dict(torch.load('doublePhotonNet_251001_2.pth', weights_only=True)) +model.load_state_dict(torch.load(f'doublePhotonNet_{modelVersion}.pth', weights_only=True)) test(model, testLoader) torch.save(model.state_dict(), f'doublePhotonNet_{modelVersion}.pth')