86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
import sys
|
|
sys.path.append('./src')
|
|
import torch
|
|
import numpy as np
|
|
from models import *
|
|
from datasets import *
|
|
import torch.optim as optim
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
# Per-epoch average losses. train() appends to TrainLosses and test() appends
# to TestLosses; both lists are plotted at the end of the run.
TrainLosses, TestLosses = [], []


def train(model, trainLoader, optimizer):
    """Run one optimisation epoch and record the sample-weighted mean MSE.

    The regression target is the first two columns of each label batch
    (the ``x`` and ``y`` columns; any further columns are ignored).

    Args:
        model: ``torch.nn.Module`` whose output is compared against
            ``label[:, :2]`` with ``mse_loss``.
        trainLoader: iterable of ``(sample, label)`` tensor batches.
        optimizer: optimiser bound to ``model``'s parameters.

    Side effects:
        Prints the epoch's average loss and appends it to the module-level
        ``TrainLosses`` list.

    Raises:
        ZeroDivisionError: if ``trainLoader`` yields no samples.
    """
    model.train()

    # Follow the model's device instead of hard-coding CUDA, so the same
    # function works for CPU runs and for models placed on a specific GPU.
    firstParam = next(model.parameters(), None)
    device = firstParam.device if firstParam is not None else torch.device("cpu")

    lossSum = 0.0
    nSeen = 0
    for sample, label in trainLoader:
        # non_blocking only has an effect for pinned host memory -> GPU copies.
        sample = sample.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True)
        target = label[:, :2]

        optimizer.zero_grad()
        output = model(sample)
        loss = torch.nn.functional.mse_loss(output, target)
        loss.backward()
        optimizer.step()

        # mse_loss returns the batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        batchSize = sample.shape[0]
        lossSum += loss.item() * batchSize
        nSeen += batchSize

    # Divide by the samples actually processed: this stays correct when the
    # loader drops the last batch or uses a sampler that skips samples.
    avgLoss = lossSum / nSeen
    print(f"[Train] Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
    TrainLosses.append(avgLoss)
|
|
|
|
def test(model, testLoader):
    """Evaluate ``model`` on ``testLoader`` and record the mean MSE.

    The regression target is the first two columns of each label batch
    (the ``x`` and ``y`` columns; any further columns are ignored). No
    gradients are computed and the model is switched to eval mode.

    Args:
        model: ``torch.nn.Module`` whose output is compared against
            ``label[:, :2]`` with ``mse_loss``.
        testLoader: iterable of ``(sample, label)`` tensor batches.

    Returns:
        The sample-weighted average MSE over all batches.

    Side effects:
        Prints the average loss and appends it to the module-level
        ``TestLosses`` list.

    Raises:
        ZeroDivisionError: if ``testLoader`` yields no samples.
    """
    model.eval()

    # Follow the model's device instead of hard-coding CUDA, so the same
    # function works for CPU runs and for models placed on a specific GPU.
    firstParam = next(model.parameters(), None)
    device = firstParam.device if firstParam is not None else torch.device("cpu")

    lossSum = 0.0
    nSeen = 0
    with torch.no_grad():
        for sample, label in testLoader:
            sample = sample.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)
            target = label[:, :2]

            output = model(sample)
            loss = torch.nn.functional.mse_loss(output, target)

            # mse_loss returns the batch mean; weight it by the batch size so
            # a smaller final batch does not skew the average.
            batchSize = sample.shape[0]
            lossSum += loss.item() * batchSize
            nSeen += batchSize

    # Divide by the samples actually processed rather than the dataset size.
    avgLoss = lossSum / nSeen
    print(f"[Test] Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
    TestLosses.append(avgLoss)
    return avgLoss
|
|
|
|
from glob import glob

# Imported up front so a missing matplotlib fails before training starts,
# not after the final epoch.
import matplotlib.pyplot as plt

# --- Model -----------------------------------------------------------------
model = singlePhotonNet_250909().cuda()

# --- Data ------------------------------------------------------------------
# glob() returns paths in arbitrary order. Sort both lists so the i-th sample
# file and the i-th label file share the same run identifier.
# NOTE(review): this assumes singlePhotonDataset pairs the two lists by
# position - confirm against its implementation.
sampleFileList = sorted(glob('/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_*_samples.npy'))
labelFileList = sorted(glob('/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_*_labels.npy'))
if not sampleFileList or len(sampleFileList) != len(labelFileList):
    raise FileNotFoundError(
        f"Expected matching sample/label files, found "
        f"{len(sampleFileList)} sample file(s) and {len(labelFileList)} label file(s)"
    )

fullDataset = singlePhotonDataset(sampleFileList, labelFileList, sampleRatio=0.1)

# 80 % / 20 % random train/test split.
nTrainSamples = int(0.8 * len(fullDataset))
nTestSamples = len(fullDataset) - nTrainSamples
trainDataset, testDataset = torch.utils.data.random_split(fullDataset, [nTrainSamples, nTestSamples])

trainLoader = torch.utils.data.DataLoader(
    trainDataset,
    batch_size=1024,
    pin_memory=True,
    shuffle=True,
    num_workers=16,
)
testLoader = torch.utils.data.DataLoader(
    testDataset,
    batch_size=4096,
    pin_memory=True,  # match trainLoader
    shuffle=False,
    num_workers=16,
)

# --- Optimisation ----------------------------------------------------------
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# The deprecated ``verbose`` flag is not passed; learning-rate changes are
# reported explicitly in the training loop below.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience=5)

if __name__ == "__main__":
    for epoch in tqdm(range(1, 201)):
        train(model, trainLoader, optimizer)
        testLoss = test(model, testLoader)

        # Step the plateau scheduler on the test loss and report any change.
        lrBefore = optimizer.param_groups[0]['lr']
        scheduler.step(testLoss)
        lrAfter = optimizer.param_groups[0]['lr']
        if lrAfter != lrBefore:
            print(f"[Epoch {epoch}] Reducing learning rate: {lrBefore:.4e} -> {lrAfter:.4e}")

    # Save and plot only after an actual training run, never on import.
    torch.save(model.state_dict(), 'singlePhotonNet_250909.pth')

    plt.figure(figsize=(8, 6))
    plt.plot(TrainLosses, label='Train Loss')
    plt.plot(TestLosses, label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid()
    plt.savefig('loss_curve.png', dpi=300)