Add Inferred sample output
This commit is contained in:
@@ -5,6 +5,8 @@ import models
|
||||
from datasets import *
|
||||
from tqdm import tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
import h5py
|
||||
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
@@ -85,12 +87,40 @@ if __name__ == "__main__":
|
||||
|
||||
referencePoints.append(dataset.referencePoint)
|
||||
|
||||
_chunk_predictions = []
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataLoader):
|
||||
inputs, _ = batch
|
||||
inputs = inputs.cuda()
|
||||
outputs = model(inputs)[:, :2] # only x and y
|
||||
predictions.append(outputs.cpu())
|
||||
inputs_cuda = inputs.cuda()
|
||||
outputs = model(inputs_cuda)[:, :2].cpu() # only x and y
|
||||
_chunk_predictions.append(outputs)
|
||||
predictions.extend(_chunk_predictions)
|
||||
### save samples and inferred positions
|
||||
_h5_file = h5py.File(f'InferredSamples/Chunk{idxChunk}.h5', 'w')
|
||||
dset_1Photon_clusters = _h5_file.create_dataset(
|
||||
'clusters', (0, 5, 5), maxshape=(None, 5, 5), dtype='f4',
|
||||
chunks=True, compression='gzip'
|
||||
)
|
||||
dset_1photon_label = _h5_file.create_dataset(
|
||||
'labels', (0, 4), maxshape=(None, 4), dtype='f4',
|
||||
chunks=True
|
||||
)
|
||||
_len = dataset.samples.shape[0]
|
||||
dset_1Photon_clusters.resize((_len, 5, 5))
|
||||
dset_1photon_label.resize((_len, 4))
|
||||
_chunk_samples = np.zeros(( _len, 5, 5), dtype=np.float32)
|
||||
_chunk_samples[:, 1:-1, 1:-1] = dataset.samples[:, 0, :, :]
|
||||
dset_1Photon_clusters[:] = _chunk_samples
|
||||
|
||||
_chunk_predictions = torch.cat(_chunk_predictions, dim=0)
|
||||
_chunk_predictions = apply_inverse_transforms(_chunk_predictions, numberOfAugOps)
|
||||
_chunk_labels = np.zeros((_len, 4), dtype=np.float32)
|
||||
_chunk_labels[:, :2] = _chunk_predictions.numpy()
|
||||
dset_1photon_label[:] = _chunk_labels
|
||||
_h5_file.close()
|
||||
|
||||
np.savez(f'InferredSamples/Chunk{idxChunk}.npz', samples=_chunk_samples, labels=_chunk_labels)
|
||||
|
||||
|
||||
predictions = torch.cat(predictions, dim=0)
|
||||
predictions = apply_inverse_transforms(predictions, numberOfAugOps)
|
||||
|
||||
Reference in New Issue
Block a user