Rename single photon files
This commit is contained in:
@@ -0,0 +1,182 @@
|
||||
import sys
|
||||
sys.path.append('./src')
|
||||
import torch
|
||||
import numpy as np
|
||||
import models
|
||||
from datasets import *
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torchinfo import summary
|
||||
|
||||
### Random seed for reproducibility across torch (CPU + CUDA), numpy and cuDNN.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True   # force deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False      # autotuning would break determinism

### Run configuration
modelVersion = '251022' # '250909' or '251020'
Energy = '15keV'

# Loss/metric bookkeeping, filled in by train()/test().
# BUG FIX: TrainLosses used to be initialised twice (the second assignment
# silently discarded the first); every list is now defined exactly once.
TrainLosses, ValLosses, TestLosses = [], [], []
LearningRates = []
TestLoss = -1  # sentinel; overwritten by test() when run on the 'Test' dataset

model = models.get_model_class(modelVersion)().cuda()
# summary(model, input_size=(128, 1, 3, 3))

LearningRate = 1e-3
Noise = 0.23 # in keV
NoiseThreshold = 0 * Noise # in keV, set values below this threshold to zero
numberOfAugOps = 1 # 1 (no augmentation) or (1,8] (with augmentation)
flag_normalize = False
|
||||
def weighted_loss(pred, target, alpha=7.0, beta=3.0):
    """Weighted L1 loss on the (x, y) position channels.

    Each coordinate's absolute error is scaled by two factors:
      * direction_weight = 1 + alpha * |target|   -- per-axis emphasis on hits
        far from the pixel centre, shape (B, 2);
      * radial_weight    = 1 + beta * r, r = ||target||  -- radial emphasis,
        shape (B, 1), broadcast over both axes.

    Args:
        pred:   (B, >=2) tensor; only columns 0-1 (x, y) are used.
        target: (B, >=2) tensor; only columns 0-1 (x, y) are used.
        alpha:  per-axis weighting strength.
        beta:   radial weighting strength (previously hard-coded to 3.0;
                now a parameter with the same default, callers unaffected).

    Returns:
        Scalar tensor: mean of the weighted absolute errors.
    """
    pred = pred[:, :2]
    target = target[:, :2]

    direction_weight = 1.0 + alpha * torch.abs(target)   # (B, 2)
    r = torch.norm(target, dim=1, keepdim=True)          # (B, 1)
    radial_weight = 1.0 + beta * r                       # (B, 1)
    weights = radial_weight * direction_weight           # (B, 2) via broadcast

    loss = weights * torch.abs(pred - target)
    return loss.mean()

# train()/test() call the loss through this alias so it can be swapped easily.
LossFunction = weighted_loss
|
||||
|
||||
def train(model, trainLoader, optimizer):
    """Run one training epoch over ``trainLoader``.

    Optimises ``model`` with the module-level LossFunction on the (x, y, z)
    label columns, prints the epoch's average loss and the per-axis RMS of
    the (x, y) predictions, and appends the average loss to TrainLosses.
    """
    model.train()
    runningLoss = 0.0
    sqErrX, sqErrY = 0.0, 0.0

    for sample, label in trainLoader:
        sample, label = sample.cuda(), label.cuda()
        # label columns: x, y, z position and energy e (e is not regressed)
        x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]

        optimizer.zero_grad()
        output = model(sample)
        loss = LossFunction(output, torch.stack((x, y, z), axis=1))
        loss.backward()
        optimizer.step()

        # weight each batch by its size so the epoch average is exact
        runningLoss += loss.item() * sample.shape[0]
        sqErrX += torch.sum((output[:, 0] - x) ** 2).item()
        sqErrY += torch.sum((output[:, 1] - y) ** 2).item()

    nTotal = len(trainLoader.dataset)
    avgLoss = runningLoss / nTotal
    rms_x = np.sqrt(sqErrX / nTotal)
    rms_y = np.sqrt(sqErrY / nTotal)

    datasetName = trainLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    TrainLosses.append(avgLoss)
|
||||
|
||||
def test(model, testLoader):
    """Evaluate ``model`` on ``testLoader`` without gradient tracking.

    Prints the average loss; appends it to ValLosses when the dataset is
    named 'Val', otherwise records it in the module-level TestLoss.

    Returns:
        The average loss over the whole dataset.
    """
    global TestLoss  # declared once up front; only assigned for non-'Val' sets
    model.eval()
    runningLoss = 0.0

    with torch.no_grad():
        for sample, label in testLoader:
            sample, label = sample.cuda(), label.cuda()
            # label columns: x, y, z position and energy e (e unused here)
            x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
            output = model(sample)
            loss = LossFunction(output, torch.stack((x, y, z), axis=1))
            # weight by batch size so the final average is exact
            runningLoss += loss.item() * sample.shape[0]

    avgLoss = runningLoss / len(testLoader.dataset)

    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
    if datasetName == 'Val':
        ValLosses.append(avgLoss)
    else:
        TestLoss = avgLoss
    return avgLoss
|
||||
|
||||
# Simulation samples: files 0-12 train, file 13 validation, file 15 test.
sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'

trainDataset = singlePhotonDataset(
    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)],
    sampleRatio=1.0,
    datasetName='Train',
    noiseKeV=Noise,
    numberOfAugOps=numberOfAugOps,
    normalize=flag_normalize,
    noiseThreshold=NoiseThreshold,
)
valDataset = singlePhotonDataset(
    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13, 14)],
    sampleRatio=1.0,
    datasetName='Val',
    noiseKeV=Noise,
    numberOfAugOps=numberOfAugOps,
    normalize=flag_normalize,
    noiseThreshold=NoiseThreshold,
)
testDataset = singlePhotonDataset(
    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15, 16)],
    sampleRatio=1.0,
    datasetName='Test',
    noiseKeV=Noise,
    numberOfAugOps=numberOfAugOps,
    normalize=flag_normalize,
    noiseThreshold=NoiseThreshold,
)

trainLoader = torch.utils.data.DataLoader(
    trainDataset,
    batch_size=4096,
    shuffle=True,
    num_workers=32,
    pin_memory=True,
)
valLoader = torch.utils.data.DataLoader(
    valDataset,
    batch_size=1024,
    shuffle=False,
    num_workers=32,
    pin_memory=True,
)
testLoader = torch.utils.data.DataLoader(
    testDataset,
    batch_size=1024,
    # BUG FIX: was shuffle=4096 (the batch size pasted into the wrong kwarg;
    # any truthy value shuffles). Evaluation data must not be shuffled,
    # matching valLoader.
    shuffle=False,
    num_workers=32,
    pin_memory=True,
)

# Adam + plateau scheduler (LR * 0.7 after 3 epochs without val improvement).
optimizer = optim.Adam(model.parameters(), lr=LearningRate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience=3)
|
||||
|
||||
def plot_loss_curve(TrainLosses, ValLosses, modelVersion, TestLoss=0):
    """Plot train/validation loss curves (log y-axis) and save to Results/.

    Args:
        TrainLosses: per-epoch training losses.
        ValLosses: per-epoch validation losses.
        modelVersion: embedded in the output file name.
        TestLoss: if > 0, drawn as a horizontal dashed reference line.
    """
    import matplotlib.pyplot as plt  # local import: only needed when plotting

    plt.figure(figsize=(8, 6))
    plt.plot(TrainLosses, label='Train Loss', color='blue')
    plt.plot(ValLosses, label='Validation Loss', color='orange')
    if TestLoss > 0:
        plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    # Label fix: the objective is the weighted L1 loss, not MSE.
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()
    plotName = f'loss_curve_singlePhoton_{modelVersion}'
    if flag_normalize:  # module-level flag; tag the file so runs don't collide
        plotName += '_normalized'
    plt.savefig(f'Results/{plotName}.png')
    # BUG FIX: close the figure — this is called at every checkpoint epoch and
    # open figures previously accumulated for the whole run.
    plt.close()
|
||||
|
||||
if __name__ == "__main__":

    def _checkpoint_name(epoch):
        """Build the checkpoint file stem for the current configuration.

        Previously this name-building logic was duplicated verbatim for the
        in-loop and final saves; one helper keeps them in sync.
        """
        name = f'singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}_aug{numberOfAugOps}'
        if flag_normalize:
            name += '_normalized'
        return name

    for epoch in tqdm(range(1, 151)):
        train(model, trainLoader, optimizer)
        test(model, valLoader)
        scheduler.step(ValLosses[-1])  # plateau scheduler keyed on val loss
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")

        # Periodic checkpoints with a matching loss-curve snapshot.
        if epoch in [20, 30, 50, 100, 150]:
            modelName = _checkpoint_name(epoch)
            torch.save(model.state_dict(), f'Models/{modelName}.pth')
            print(f"Saved model checkpoint: {modelName}.pth")
            plot_loss_curve(TrainLosses, ValLosses, modelVersion=modelVersion)

    # Final evaluation on the held-out test set (records module-level TestLoss),
    # followed by a final save and a plot that includes the test-loss line.
    test(model, testLoader)
    modelName = _checkpoint_name(epoch)
    torch.save(model.state_dict(), f'Models/{modelName}.pth')
    print(f"Saved final model checkpoint: {modelName}.pth")
    plot_loss_curve(TrainLosses, ValLosses, modelVersion=modelVersion, TestLoss=TestLoss)
|
||||
Reference in New Issue
Block a user