Augment for single photon
This commit is contained in:
@@ -24,6 +24,7 @@ model = models.get_model_class(modelVersion)().cuda()
|
||||
# summary(model, input_size=(128, 1, 3, 3))
|
||||
LearningRate = 1e-3
|
||||
Noise = 0.13 # in keV
|
||||
numberOfAugOps = 8 # 1 (no augmentation) or (1,8] (with augmentation)
|
||||
|
||||
TrainLosses, TestLosses = [], []
|
||||
def weighted_loss(pred, target, alpha=7.0):
|
||||
@@ -87,12 +88,14 @@ trainDataset = singlePhotonDataset(
|
||||
sampleRatio=1.0,
|
||||
datasetName='Train',
|
||||
noiseKeV = Noise,
|
||||
numberOfAugOps=numberOfAugOps,
|
||||
)
|
||||
valDataset = singlePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)],
|
||||
sampleRatio=1.0,
|
||||
datasetName='Val',
|
||||
noiseKeV = Noise,
|
||||
numberOfAugOps=numberOfAugOps,
|
||||
)
|
||||
testDataset = singlePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)],
|
||||
@@ -104,37 +107,37 @@ trainLoader = torch.utils.data.DataLoader(
|
||||
trainDataset,
|
||||
batch_size=4096,
|
||||
shuffle=True,
|
||||
num_workers=16,
|
||||
num_workers=32,
|
||||
pin_memory=True,
|
||||
)
|
||||
valLoader = torch.utils.data.DataLoader(
|
||||
valDataset,
|
||||
batch_size=1024,
|
||||
shuffle=False,
|
||||
num_workers=16,
|
||||
num_workers=32,
|
||||
pin_memory=True,
|
||||
)
|
||||
testLoader = torch.utils.data.DataLoader(
|
||||
testDataset,
|
||||
batch_size=1024,
|
||||
shuffle=4096,
|
||||
num_workers=16,
|
||||
num_workers=32,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 3)
|
||||
if __name__ == "__main__":
|
||||
for epoch in tqdm(range(1, 301)):
|
||||
for epoch in tqdm(range(1, 1001)):
|
||||
train(model, trainLoader, optimizer)
|
||||
test(model, valLoader)
|
||||
scheduler.step(ValLosses[-1])
|
||||
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
|
||||
if epoch in [20, 50, 100, 200, 300, 500]:
|
||||
torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth')
|
||||
if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]:
|
||||
torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}_aug{numberOfAugOps}.pth')
|
||||
|
||||
test(model, testLoader)
|
||||
torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV.pth')
|
||||
torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_aug{numberOfAugOps}.pth')
|
||||
|
||||
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
Reference in New Issue
Block a user