Update 2Photon training
This commit is contained in:
+17
-4
@@ -12,13 +12,15 @@ from torchinfo import summary
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
modelVersion = '251001_2' # '250910' or '251001'
|
||||
modelVersion = '251124' # '250910' or '251001' or '251124'
|
||||
model = models.get_double_photon_model_class(modelVersion)().cuda()
|
||||
Energy = '15.3keV'
|
||||
Energy = '12keV'
|
||||
TrainLosses, ValLosses = [], []
|
||||
TestLoss = -1
|
||||
LearningRate = 1e-3
|
||||
Noise = 0.13 # in keV
|
||||
NoiseThreshold = 3*Noise
|
||||
Normalize = False
|
||||
|
||||
def two_point_set_loss_l2(pred_xy, gt_xy):
|
||||
def pair_cost_l2sq(p, q): # p,q: (...,2)
|
||||
@@ -118,6 +120,9 @@ trainDataset = doublePhotonDataset(
|
||||
datasetName='Train',
|
||||
reuselFactor=1,
|
||||
noiseKeV = Noise,
|
||||
nSize=7,
|
||||
noiseThreshold = NoiseThreshold,
|
||||
normalize = Normalize
|
||||
)
|
||||
valDataset = doublePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)],
|
||||
@@ -125,6 +130,9 @@ valDataset = doublePhotonDataset(
|
||||
datasetName='Val',
|
||||
reuselFactor=1,
|
||||
noiseKeV = Noise,
|
||||
nSize=7,
|
||||
noiseThreshold = NoiseThreshold,
|
||||
normalize = Normalize
|
||||
)
|
||||
testDataset = doublePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)],
|
||||
@@ -132,6 +140,9 @@ testDataset = doublePhotonDataset(
|
||||
datasetName='Test',
|
||||
reuselFactor=1,
|
||||
noiseKeV = Noise,
|
||||
nSize=7,
|
||||
noiseThreshold = NoiseThreshold,
|
||||
normalize = Normalize
|
||||
)
|
||||
trainLoader = torch.utils.data.DataLoader(
|
||||
trainDataset,
|
||||
@@ -161,10 +172,12 @@ if __name__ == "__main__":
|
||||
scheduler.step(ValLosses[-1])
|
||||
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
|
||||
if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]:
|
||||
torch.save(model.state_dict(), f'Models/doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth')
|
||||
modelName = f'doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}'
|
||||
if Normalize:
|
||||
modelName += '_normalized'
|
||||
torch.save(model.state_dict(), f'Models/{modelName}.pth')
|
||||
|
||||
test(model, testLoader)
|
||||
torch.save(model.state_dict(), f'Models/doublePhotonNet_{modelVersion}.pth')
|
||||
|
||||
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
Reference in New Issue
Block a user