From e80fa0362b71b6fdaa71eb6f6eaf9b94646e8470 Mon Sep 17 00:00:00 2001
From: "xiangyu.xie"
Date: Wed, 22 Oct 2025 08:02:11 +0200
Subject: [PATCH] Add h5 support; add refpoints

---
Review notes (this section sits after the "---" separator, so it is not part
of the commit message):

* Line structure, indentation and blank/whitespace-only lines in this patch
  were reconstructed from a copy whose whitespace had been collapsed. The
  hunk line counts match the "@@" headers, but exact whitespace should be
  checked against the repository before applying.
* Both loaders call np.load() on an archive and never close the returned
  object; consider `with np.load(sampleFile) as data:`.
* singlePhotonDataset selects the reader with substring tests
  ('.npz' in sampleFile / '.h5' in sampleFile) and has no else branch, so a
  path matching neither is skipped without any message.
* The `if all_ref_pts else None` guard comes after
  np.concatenate(all_samples, axis=0), which already raises on an empty list.
* The 5x5 -> 3x3 crop shifts labels by [1, 1, 0, 0] but leaves
  referencePoint unchanged; confirm this is intended.
* doublePhotonDataset still uses enumerate(), but `idx` is no longer read.

 src/datasets.py | 73 +++++++++++++++++++++++++++++++++++-------------
 1 file changed, 53 insertions(+), 20 deletions(-)

diff --git a/src/datasets.py b/src/datasets.py
index b217bb4..215b7ee 100644
--- a/src/datasets.py
+++ b/src/datasets.py
@@ -3,22 +3,55 @@ import torch
 import numpy as np
 
 class singlePhotonDataset(Dataset):
-    def __init__(self, sampleList, labelList, sampleRatio):
+    def __init__(self, sampleList, sampleRatio, datasetName, noiseKeV=0):
         self.sampleFileList = sampleList
-        self.labelFileList = labelList
         self.sampleRatio = sampleRatio
+        self.datasetName = datasetName
+
+        all_samples = []
+        all_labels = []
+        all_ref_pts = []
+        for sampleFile in self.sampleFileList:
+            if '.npz' in sampleFile:
+                data = np.load(sampleFile)
+                all_samples.append(data['samples'])
+                if 'referencePoint' in data:
+                    all_ref_pts.append(data['referencePoint'])
+                else:
+                    all_ref_pts.append(np.zeros((data['samples'].shape[0], 2), dtype=np.float32)) ### dummy reference points
+                if 'labels' in data:
+                    all_labels.append(data['labels'])
+                else:
+                    all_labels.append(np.zeros((data['samples'].shape[0], 4), dtype=np.float32)) ### dummy labels
+            elif '.h5' in sampleFile:
+                import h5py
+                with h5py.File(sampleFile, 'r') as f:
+                    samples = f['clusters'][:]
+                    if 'referencePoint' in f:
+                        ref_pts = f['referencePoint'][:]
+                        all_ref_pts.append(ref_pts)
+                    else:
+                        all_ref_pts.append(np.zeros((samples.shape[0], 2), dtype=np.float32)) ### dummy reference points
+                    if 'labels' in f:
+                        labels = f['labels'][:]
+                    else:
+                        labels = np.zeros((samples.shape[0], 4), dtype=np.float32) ### dummy labels
+                    all_samples.append(samples)
+                    all_labels.append(labels)
+        self.samples = np.concatenate(all_samples, axis=0)
+        if noiseKeV != 0:
+            print(f'Adding Gaussian noise with sigma = {noiseKeV} keV to samples in {self.datasetName} dataset')
+            noise = np.random.normal(loc=0.0, scale=noiseKeV, size=self.samples.shape)
+            self.samples = self.samples + noise
+        self.labels = np.concatenate(all_labels, axis=0)
+        self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None
 
-        for idx, sampleFile in enumerate(self.sampleFileList):
-            if idx == 0:
-                self.samples = np.load(sampleFile)
-                self.labels = np.load(self.labelFileList[idx])
-            else:
-                self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
-                self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)
-
+        if self.samples.shape[1] == 5: ### if sample size is 5x5, remove border pixels to make it 3x3
+            self.samples = self.samples[:, 1:-1, 1:-1] ### remove border pixels
+            self.labels = self.labels - np.array([1, 1, 0, 0]) ### adjust labels accordingly
         ### total number of samples
         self.length = int(self.samples.shape[0] * self.sampleRatio)
-        print(f"Total number of samples: {self.length}")
+        print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
 
     def __getitem__(self, index):
         sample = self.samples[index]
@@ -30,20 +63,20 @@ class singlePhotonDataset(Dataset):
         return self.length
 
 class doublePhotonDataset(Dataset):
-    def __init__(self, sampleList, labelList, sampleRatio, datasetName, reuselFactor=1):
+    def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1):
         self.sampleFileList = sampleList
-        self.labelFileList = labelList
         self.sampleRatio = sampleRatio
         self.datasetName = datasetName
 
+        all_samples = []
+        all_labels = []
         for idx, sampleFile in enumerate(self.sampleFileList):
-            if idx == 0:
-                self.samples = np.load(sampleFile)
-                self.labels = np.load(self.labelFileList[idx])
-            else:
-                self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
-                self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)
-
+            data = np.load(sampleFile)
+            all_samples.append(data['samples'])
+            all_labels.append(data['labels'])
+        self.samples = np.concatenate(all_samples, axis=0)
+        self.labels = np.concatenate(all_labels, axis=0)
+
         ### total number of samples
         self.length = int(self.samples.shape[0] * self.sampleRatio) // 2 * reuselFactor
         print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")