Chunkize the inference
This commit is contained in:
@@ -6,10 +6,16 @@ from datasets import *
|
||||
from tqdm import tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
configs = {}
|
||||
configs['SiemenStar'] = {
|
||||
'dataFiles': [f'/mnt/sls_det_storage/moench_data/MLXID/Samples/Measurement/2504_SOLEIL_SiemenStarClusters_MOENCH040_150V/clusters_chunk{i}.h5' for i in range(32)],
|
||||
'modelVersion': '251020',
|
||||
'dataFiles': [f'/mnt/sls_det_storage/moench_data/MLXID/Samples/Measurement/2504_SOLEIL_SiemenStarClusters_MOENCH040_150V/clusters_chunk{i}.h5' for i in range(200)],
|
||||
'modelVersion': '251022',
|
||||
'roi': [140, 230, 120, 210], # x_min, x_max, y_min, y_max,
|
||||
'noise': 0.13 # in keV
|
||||
}
|
||||
@@ -25,28 +31,36 @@ if __name__ == "__main__":
|
||||
config = configs[task]
|
||||
|
||||
model = models.get_model_class(config['modelVersion'])().cuda()
|
||||
model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/singlePhotonNet_Noise{config["noise"]}keV_{config["modelVersion"]}.pth', weights_only=True))
|
||||
|
||||
dataset = singlePhotonDataset(config['dataFiles'], sampleRatio=1.0, datasetName='Inference')
|
||||
dataLoader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=4096,
|
||||
shuffle=False,
|
||||
num_workers=16,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
referencePoints = dataset.referencePoint
|
||||
model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/singlePhoton{config["modelVersion"]}_15.3keV_Noise{config["noise"]}keV_E500.pth', weights_only=True))
|
||||
predictions = []
|
||||
referencePoints = []
|
||||
nChunks = len(config['dataFiles']) // 32 + 1
|
||||
for idxChunk in range(nChunks):
|
||||
stFileIdx = idxChunk * 32
|
||||
edFileIdx = min((idxChunk + 1) * 32, len(config['dataFiles']))
|
||||
sampleFiles = config['dataFiles'][stFileIdx : edFileIdx]
|
||||
print(f'Processing files {stFileIdx} to {edFileIdx}...')
|
||||
dataset = singlePhotonDataset(sampleFiles, sampleRatio=1.0, datasetName='Inference')
|
||||
dataLoader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=8192,
|
||||
shuffle=False,
|
||||
num_workers=16,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataLoader):
|
||||
inputs, _ = batch
|
||||
inputs = inputs.cuda()
|
||||
outputs = model(inputs)
|
||||
predictions.append(outputs.cpu())
|
||||
referencePoints.append(dataset.referencePoint)
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataLoader):
|
||||
inputs, _ = batch
|
||||
inputs = inputs.cuda()
|
||||
outputs = model(inputs)[:, :2] # only x and y
|
||||
predictions.append(outputs.cpu())
|
||||
|
||||
predictions = torch.cat(predictions, dim=0)
|
||||
predictions += torch.tensor([1.5, 1.5]).unsqueeze(0) # adjust back to original coordinate system
|
||||
referencePoints = np.concatenate(referencePoints, axis=0)
|
||||
print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}')
|
||||
print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}')
|
||||
absolutePositions = predictions.numpy() + referencePoints[:, :2] - 1
|
||||
|
||||
Reference in New Issue
Block a user