Add DL codes
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
class singlePhotonDataset(Dataset):
|
||||
def __init__(self, sampleList, labelList, sampleRatio):
|
||||
self.sampleFileList = sampleList
|
||||
self.labelFileList = labelList
|
||||
self.sampleRatio = sampleRatio
|
||||
|
||||
for idx, sampleFile in enumerate(self.sampleFileList):
|
||||
if idx == 0:
|
||||
self.samples = np.load(sampleFile)
|
||||
self.labels = np.load(self.labelFileList[idx])
|
||||
else:
|
||||
self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
|
||||
self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)
|
||||
|
||||
### total number of samples
|
||||
self.length = int(self.samples.shape[0] * self.sampleRatio)
|
||||
print(f"Total number of samples: {self.length}")
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
sample = np.expand_dims(sample, axis=0)
|
||||
label = self.labels[index]
|
||||
return torch.tensor(sample, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
class doublePhotonDataset(Dataset):
|
||||
def __init__(self, sampleList, labelList, sampleRatio, datasetName, reuselFactor=1):
|
||||
self.sampleFileList = sampleList
|
||||
self.labelFileList = labelList
|
||||
self.sampleRatio = sampleRatio
|
||||
self.datasetName = datasetName
|
||||
|
||||
for idx, sampleFile in enumerate(self.sampleFileList):
|
||||
if idx == 0:
|
||||
self.samples = np.load(sampleFile)
|
||||
self.labels = np.load(self.labelFileList[idx])
|
||||
else:
|
||||
self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
|
||||
self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)
|
||||
|
||||
### total number of samples
|
||||
self.length = int(self.samples.shape[0] * self.sampleRatio) // 2 * reuselFactor
|
||||
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = np.zeros((8, 8), dtype=np.float32)
|
||||
idx1 = np.random.randint(0, self.samples.shape[0])
|
||||
idx2 = np.random.randint(0, self.samples.shape[0])
|
||||
photon1 = self.samples[idx1]
|
||||
photon2 = self.samples[idx2]
|
||||
|
||||
### random position for photons in
|
||||
pos_x1 = np.random.randint(1, 4)
|
||||
pos_y1 = np.random.randint(1, 4)
|
||||
sample[pos_y1:pos_y1+5, pos_x1:pos_x1+5] += photon1
|
||||
pos_x2 = np.random.randint(1, 4)
|
||||
pos_y2 = np.random.randint(1, 4)
|
||||
sample[pos_y2:pos_y2+5, pos_x2:pos_x2+5] += photon2
|
||||
sample = sample[1:-1, 1:-1] ### sample size: 6x6
|
||||
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) / 1000. ### to keV
|
||||
|
||||
label1 = self.labels[idx1] + np.array([pos_x1, pos_y1, 0, 0]) - 1
|
||||
label2 = self.labels[idx2] + np.array([pos_x2, pos_y2, 0, 0]) - 1
|
||||
# label = np.concatenate((label1, label2), axis=0)
|
||||
if label1[0] < label2[0]:
|
||||
label = np.concatenate((label1, label2), axis=0)
|
||||
else:
|
||||
label = np.concatenate((label2, label1), axis=0)
|
||||
return sample, torch.tensor(label, dtype=torch.float32)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
Reference in New Issue
Block a user