Set random seed for reproducibility
This commit is contained in:
@@ -6,6 +6,11 @@ from models import *
|
||||
from datasets import *
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torchinfo import summary
|
||||
|
||||
### random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
modelVersion = '251001_2' # '250910' or '251001'
|
||||
TrainLosses, ValLosses = [], []
|
||||
@@ -45,6 +50,8 @@ elif modelVersion == '251001_2':
|
||||
loss_fn = min_matching_loss
|
||||
model = doublePhotonNet_251001_2().cuda()
|
||||
|
||||
# summary(model, input_size=(128, 1, 6, 6)) ### print model summary
|
||||
|
||||
def train(model, trainLoader, optimizer):
|
||||
model.train()
|
||||
batchLoss = 0
|
||||
@@ -142,7 +149,7 @@ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, p
|
||||
# test(model, valLoader)
|
||||
# scheduler.step(TrainLosses[-1])
|
||||
|
||||
model.load_state_dict(torch.load('doublePhotonNet_251001_2.pth', weights_only=True))
|
||||
model.load_state_dict(torch.load(f'doublePhotonNet_{modelVersion}.pth', weights_only=True))
|
||||
test(model, testLoader)
|
||||
torch.save(model.state_dict(), f'doublePhotonNet_{modelVersion}.pth')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user