Weighted loss

This commit is contained in:
2025-10-27 18:20:32 +01:00
parent 15b621e359
commit 535e9f057a

View File

@@ -10,14 +10,37 @@ from torchinfo import summary
### random seed for reproducibility
# Seeds the torch (CPU and CUDA) and NumPy generators with 0 and asks cuDNN
# for deterministic kernels with benchmark mode switched off.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Run configuration. modelVersion is passed to models.get_model_class() and,
# together with Energy, is interpolated into the file names used further down.
# NOTE(review): the inline list of versions does not include the value
# assigned here ('251022') — confirm the comment is still current.
modelVersion = '251022' # '250909' or '251020'
Energy = '15.3keV'
# Loss histories; the main loop reads the last element of TrainLosses /
# ValLosses when stepping the scheduler. Where they are appended to is not
# visible in this view.
TrainLosses, ValLosses = [], []
LearningRates = []
# Placeholder value; presumably overwritten by test() — TODO confirm.
TestLoss = -1
# Instantiate the model class selected by modelVersion and move it to the GPU.
model = models.get_model_class(modelVersion)().cuda()
# summary(model, input_size=(128, 1, 3, 3))
# Initial learning rate handed to the Adam optimizer.
LearningRate = 1e-3
# Passed as noiseKeV to every singlePhotonDataset below.
Noise = 0.13 # in keV
# NOTE(review): this re-binds TrainLosses and introduces TestLosses, while the
# assignment above uses ValLosses. It looks like the pre-change line of the
# diff shown next to its replacement — confirm which one is live.
TrainLosses, TestLosses = [], []
def weighted_loss(pred, target, alpha=7.0, beta=3.0):
    """Return a position-weighted L1 loss on the (x, y) coordinates.

    Only the first two columns of ``pred`` and ``target`` are compared, so
    any extra columns (the callers stack x, y and z into ``target``) are
    ignored.  Every absolute error is multiplied by

        (1 + beta * ||target_xy||_2) * (1 + alpha * |target_xy|)

    so samples whose target lies further from the origin contribute more.

    Args:
        pred: 2-D tensor; columns 0 and 1 are the predicted x and y.
        target: 2-D tensor; columns 0 and 1 are the true x and y.
        alpha: slope of the per-coordinate weight ``1 + alpha * |target|``.
        beta: slope of the radial weight ``1 + beta * r``.  Defaults to 3.0,
            the value that used to be hard-coded in the body.

    Returns:
        Scalar tensor: the mean of the weighted absolute errors over all
        batch elements and both coordinates.
    """
    pred_xy = pred[:, :2]
    target_xy = target[:, :2]

    # Per-coordinate weight, shape (B, 2).
    direction_weight = 1.0 + alpha * torch.abs(target_xy)

    # Per-sample radial weight, shape (B, 1); keepdim lets it broadcast
    # across both coordinates in the product below.
    radius = torch.norm(target_xy, dim=1, keepdim=True)
    radial_weight = 1.0 + beta * radius

    weights = radial_weight * direction_weight
    return (weights * torch.abs(pred_xy - target_xy)).mean()
# Loss callable used by the LossFunction(...) calls in train() and test().
LossFunction = weighted_loss
def train(model, trainLoader, optimizer):
model.train()
@@ -27,7 +50,7 @@ def train(model, trainLoader, optimizer):
x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3]
optimizer.zero_grad()
output = model(sample)
loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), axis=1))
loss = LossFunction(output, torch.stack((x, y, z), axis=1))
loss.backward()
optimizer.step()
batchLoss += loss.item() * sample.shape[0]
@@ -45,7 +68,7 @@ def test(model, testLoader):
sample, label = sample.cuda(), label.cuda()
x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3]
output = model(sample)
loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), axis=1))
loss = LossFunction(output, torch.stack((x, y, z), axis=1))
batchLoss += loss.item() * sample.shape[0]
avgLoss = batchLoss / len(testLoader.dataset)
@@ -58,32 +81,28 @@ def test(model, testLoader):
TestLoss = avgLoss
return avgLoss
# NOTE(review): an identical `model = ...` statement appears earlier in this
# listing; one of the two is probably the pre-change line of the diff —
# confirm that the model is constructed only once.
model = models.get_model_class(modelVersion)().cuda()
# summary(model, input_size=(128, 1, 3, 3))
# Directory holding the simulated sample files (.npz) read by the datasets.
sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'
# Train / validation / test splits are selected by file index:
#   train -> indices 0..12, val -> index 13, test -> index 15.
# NOTE(review): index 14 is not used by any split — confirm this is intended.
# NOTE(review): each call below shows two file lists, one with a hard-coded
# '15keV' prefix and one built from Energy. This looks like the old and new
# line of the diff shown together — confirm only one list is actually passed.
trainDataset = singlePhotonDataset(
[f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(13)],
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)],
sampleRatio=1.0,
datasetName='Train',
noiseKeV = Noise,
)
valDataset = singlePhotonDataset(
[f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(13,14)],
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)],
sampleRatio=1.0,
datasetName='Val',
noiseKeV = Noise,
)
testDataset = singlePhotonDataset(
[f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(15,16)],
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)],
sampleRatio=1.0,
datasetName='Test',
noiseKeV = Noise,
)
trainLoader = torch.utils.data.DataLoader(
trainDataset,
batch_size=1024,
batch_size=4096,
shuffle=True,
num_workers=16,
pin_memory=True,
@@ -103,16 +122,19 @@ testLoader = torch.utils.data.DataLoader(
pin_memory=True,
)
# Adam optimizer plus a ReduceLROnPlateau scheduler in 'min' mode with
# factor=0.7, driven by the value passed to scheduler.step() in the loop.
# NOTE(review): several statements in this section appear twice with
# different arguments (lr, patience, epoch range, scheduler metric, final
# checkpoint name). This looks like the old and new line of the diff shown
# together — confirm which of each pair is live.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 5)
optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 3)
if __name__ == "__main__":
# Per epoch: one training pass, one evaluation pass on valLoader, then a
# scheduler step and a printout of the current learning rate.
for epoch in tqdm(range(1, 101)):
for epoch in tqdm(range(1, 301)):
train(model, trainLoader, optimizer)
test(model, valLoader)
# NOTE(review): the two step() calls monitor different metrics (last training
# loss vs. last validation loss) — confirm which one is intended.
scheduler.step(TrainLosses[-1])
scheduler.step(ValLosses[-1])
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
# Intermediate checkpoints at selected epochs.
# NOTE(review): epoch 500 can never be reached with either range shown above
# (upper bounds 101 and 301) — the entry is dead unless the range is raised.
if epoch in [20, 50, 100, 200, 300, 500]:
torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth')
# Evaluation on testLoader, followed by saving the final weights.
test(model, testLoader)
torch.save(model.state_dict(), f'Models/singlePhotonNet_Noise{Noise}keV_{modelVersion}.pth')
torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV.pth')
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
import matplotlib.pyplot as plt