diff --git a/Inference_SinglePhoton.py b/Inference_SinglePhoton.py
index b610d7c..0b12542 100644
--- a/Inference_SinglePhoton.py
+++ b/Inference_SinglePhoton.py
@@ -5,6 +5,8 @@ import models
 from datasets import *
 from tqdm import tqdm
 from matplotlib import pyplot as plt
+import numpy as np
+import h5py
 
 torch.manual_seed(42)
 torch.cuda.manual_seed(42)
@@ -85,12 +87,42 @@ if __name__ == "__main__":
 
         referencePoints.append(dataset.referencePoint)
 
+        _chunk_predictions = []
         with torch.no_grad():
             for batch in tqdm(dataLoader):
                 inputs, _ = batch
-                inputs = inputs.cuda()
-                outputs = model(inputs)[:, :2]  # only x and y
-                predictions.append(outputs.cpu())
+                inputs_cuda = inputs.cuda()
+                outputs = model(inputs_cuda)[:, :2].cpu()  # only x and y
+                _chunk_predictions.append(outputs)
+        predictions.extend(_chunk_predictions)
+        ### save samples and inferred positions
+        _h5_file = h5py.File(f'InferredSamples/Chunk{idxChunk}.h5', 'w')
+        dset_1Photon_clusters = _h5_file.create_dataset(
+            'clusters', (0, 5, 5), maxshape=(None, 5, 5), dtype='f4',
+            chunks=True, compression='gzip'
+        )
+        dset_1Photon_labels = _h5_file.create_dataset(
+            'labels', (0, 4), maxshape=(None, 4), dtype='f4',
+            chunks=True
+        )
+        _len = dataset.samples.shape[0]
+        dset_1Photon_clusters.resize((_len, 5, 5))
+        dset_1Photon_labels.resize((_len, 4))
+        # embed the 3x3 samples in the centre of a zero-padded 5x5 frame
+        _chunk_samples = np.zeros((_len, 5, 5), dtype=np.float32)
+        _chunk_samples[:, 1:-1, 1:-1] = dataset.samples[:, 0, :, :]
+        dset_1Photon_clusters[:] = _chunk_samples
+
+        _chunk_predictions = torch.cat(_chunk_predictions, dim=0)
+        _chunk_predictions = apply_inverse_transforms(_chunk_predictions, numberOfAugOps)
+        # columns 0-1 hold the inverse-transformed (x, y); columns 2-3 stay zero
+        _chunk_labels = np.zeros((_len, 4), dtype=np.float32)
+        _chunk_labels[:, :2] = _chunk_predictions.numpy()
+        dset_1Photon_labels[:] = _chunk_labels
+        _h5_file.close()
+
+        np.savez(f'InferredSamples/Chunk{idxChunk}.npz', samples=_chunk_samples, labels=_chunk_labels)
+
 
     predictions = torch.cat(predictions, dim=0)
     predictions = apply_inverse_transforms(predictions, numberOfAugOps)
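
For reference, a minimal sketch of reading back one saved chunk, assuming the HDF5 layout written above ('clusters' as (N, 5, 5) float32, 'labels' as (N, 4) float32); the path 'InferredSamples/Chunk0.h5' is a hypothetical example, not part of the patch:

    import h5py

    # Open one inferred chunk written by Inference_SinglePhoton.py.
    # The dataset names 'clusters' and 'labels' come from the patch above;
    # the chunk index 0 is an arbitrary example.
    with h5py.File('InferredSamples/Chunk0.h5', 'r') as f:
        clusters = f['clusters'][:]  # (N, 5, 5) zero-padded pixel clusters
        labels = f['labels'][:]      # (N, 4); columns 0-1 are inferred (x, y)

    assert clusters.shape[1:] == (5, 5) and labels.shape[1] == 4
    print(clusters.shape[0], 'samples; mean inferred (x, y):', labels[:, :2].mean(axis=0))

The parallel .npz file holds the same 'samples'/'labels' arrays and can be read with np.load where HDF5 tooling is unavailable.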