Add DL codes

This commit is contained in:
2025-10-08 08:00:24 +02:00
commit 3b758ac42e
4 changed files with 517 additions and 0 deletions
+78
View File
@@ -0,0 +1,78 @@
from torch.utils.data import Dataset
import torch
import numpy as np
class singlePhotonDataset(Dataset):
def __init__(self, sampleList, labelList, sampleRatio):
self.sampleFileList = sampleList
self.labelFileList = labelList
self.sampleRatio = sampleRatio
for idx, sampleFile in enumerate(self.sampleFileList):
if idx == 0:
self.samples = np.load(sampleFile)
self.labels = np.load(self.labelFileList[idx])
else:
self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)
### total number of samples
self.length = int(self.samples.shape[0] * self.sampleRatio)
print(f"Total number of samples: {self.length}")
def __getitem__(self, index):
sample = self.samples[index]
sample = np.expand_dims(sample, axis=0)
label = self.labels[index]
return torch.tensor(sample, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)
def __len__(self):
return self.length
class doublePhotonDataset(Dataset):
def __init__(self, sampleList, labelList, sampleRatio, datasetName, reuselFactor=1):
self.sampleFileList = sampleList
self.labelFileList = labelList
self.sampleRatio = sampleRatio
self.datasetName = datasetName
for idx, sampleFile in enumerate(self.sampleFileList):
if idx == 0:
self.samples = np.load(sampleFile)
self.labels = np.load(self.labelFileList[idx])
else:
self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)
### total number of samples
self.length = int(self.samples.shape[0] * self.sampleRatio) // 2 * reuselFactor
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
def __getitem__(self, index):
sample = np.zeros((8, 8), dtype=np.float32)
idx1 = np.random.randint(0, self.samples.shape[0])
idx2 = np.random.randint(0, self.samples.shape[0])
photon1 = self.samples[idx1]
photon2 = self.samples[idx2]
### random position for photons in
pos_x1 = np.random.randint(1, 4)
pos_y1 = np.random.randint(1, 4)
sample[pos_y1:pos_y1+5, pos_x1:pos_x1+5] += photon1
pos_x2 = np.random.randint(1, 4)
pos_y2 = np.random.randint(1, 4)
sample[pos_y2:pos_y2+5, pos_x2:pos_x2+5] += photon2
sample = sample[1:-1, 1:-1] ### sample size: 6x6
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) / 1000. ### to keV
label1 = self.labels[idx1] + np.array([pos_x1, pos_y1, 0, 0]) - 1
label2 = self.labels[idx2] + np.array([pos_x2, pos_y2, 0, 0]) - 1
# label = np.concatenate((label1, label2), axis=0)
if label1[0] < label2[0]:
label = np.concatenate((label1, label2), axis=0)
else:
label = np.concatenate((label2, label1), axis=0)
return sample, torch.tensor(label, dtype=torch.float32)
def __len__(self):
return self.length