Correct double photon sample (remove ordering)
This commit is contained in:
+24
-11
@@ -101,10 +101,12 @@ class singlePhotonDataset(Dataset):
|
||||
return self.effectiveLength
|
||||
|
||||
class doublePhotonDataset(Dataset):
|
||||
def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1):
|
||||
def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1, noiseKeV=0):
|
||||
self.sampleFileList = sampleList
|
||||
self.sampleRatio = sampleRatio
|
||||
self.datasetName = datasetName
|
||||
self.noiseKeV = noiseKeV
|
||||
self._init_coords()
|
||||
|
||||
all_samples = []
|
||||
all_labels = []
|
||||
@@ -113,11 +115,23 @@ class doublePhotonDataset(Dataset):
|
||||
all_samples.append(data['samples'])
|
||||
all_labels.append(data['labels'])
|
||||
self.samples = np.concatenate(all_samples, axis=0)
|
||||
if self.noiseKeV != 0:
|
||||
print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset')
|
||||
noise = np.random.normal(loc=0.0, scale=self.noiseKeV, size=self.samples.shape)
|
||||
self.samples = self.samples + noise
|
||||
self.labels = np.concatenate(all_labels, axis=0)
|
||||
|
||||
### total number of samples
|
||||
self.length = int(self.samples.shape[0] * self.sampleRatio) // 2 * reuselFactor
|
||||
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
|
||||
|
||||
def _init_coords(self):
|
||||
# Create a coordinate grid for 3x3 input
|
||||
x = np.linspace(0, 5, 6)
|
||||
y = np.linspace(0, 5, 6)
|
||||
x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (6,6), (6,6)
|
||||
self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, 6, 6)
|
||||
self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, 6, 6)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = np.zeros((8, 8), dtype=np.float32)
|
||||
@@ -128,21 +142,20 @@ class doublePhotonDataset(Dataset):
|
||||
|
||||
### random position for photons in
|
||||
pos_x1 = np.random.randint(1, 4)
|
||||
pos_y1 = np.random.randint(1, 4)
|
||||
# pos_y1 = np.random.randint(1, 4)
|
||||
pos_y1 = pos_x1
|
||||
sample[pos_y1:pos_y1+5, pos_x1:pos_x1+5] += photon1
|
||||
pos_x2 = np.random.randint(1, 4)
|
||||
pos_y2 = np.random.randint(1, 4)
|
||||
# pos_y2 = np.random.randint(1, 4)
|
||||
pos_y2 = pos_x2
|
||||
sample[pos_y2:pos_y2+5, pos_x2:pos_x2+5] += photon2
|
||||
sample = sample[1:-1, 1:-1] ### sample size: 6x6
|
||||
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) / 1000. ### to keV
|
||||
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
|
||||
sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels
|
||||
|
||||
label1 = self.labels[idx1] + np.array([pos_x1, pos_y1, 0, 0]) - 1
|
||||
label2 = self.labels[idx2] + np.array([pos_x2, pos_y2, 0, 0]) - 1
|
||||
# label = np.concatenate((label1, label2), axis=0)
|
||||
if label1[0] < label2[0]:
|
||||
label = np.concatenate((label1, label2), axis=0)
|
||||
else:
|
||||
label = np.concatenate((label2, label1), axis=0)
|
||||
label1 = self.labels[idx1] + np.array([pos_x1-1, pos_y1-1, 0, 0])
|
||||
label2 = self.labels[idx2] + np.array([pos_x2-1, pos_y2-1, 0, 0])
|
||||
label = np.concatenate((label1, label2), axis=0)
|
||||
return sample, torch.tensor(label, dtype=torch.float32)
|
||||
|
||||
def __len__(self):
|
||||
|
||||
+58
-5
@@ -10,6 +10,13 @@ def get_model_class(version):
|
||||
raise ValueError(f"Model class '{class_name}' not found.")
|
||||
return cls
|
||||
|
||||
def get_double_photon_model_class(version):
|
||||
class_name = f'doublePhotonNet_{version}'
|
||||
cls = globals().get(class_name)
|
||||
if cls is None:
|
||||
raise ValueError(f"Model class '{class_name}' not found.")
|
||||
return cls
|
||||
|
||||
class singlePhotonNet_250909(nn.Module):
|
||||
def weight_init(self):
|
||||
for m in self.modules():
|
||||
@@ -181,15 +188,11 @@ class doublePhotonNet_251001(nn.Module):
|
||||
coords = self.fc(x)*6
|
||||
return coords # shape: [B, 4]
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class doublePhotonNet_251001_2(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Backbone: deeper + residual-like blocks
|
||||
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
|
||||
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
|
||||
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
|
||||
|
||||
@@ -253,4 +256,54 @@ class doublePhotonNet_251001_2(nn.Module):
|
||||
|
||||
# Regression
|
||||
coords = self.fc(global_feat) * 6.0 # scale to [0,6)
|
||||
return coords # [B, 4]
|
||||
|
||||
class doublePhotonNet_251104(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Backbone: deeper
|
||||
self.conv1 = nn.Conv2d(3, 32, kernel_size=3) # 6x6 -> 4x4
|
||||
self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # 4x4 -> 2x2
|
||||
self.conv3 = nn.Conv2d(64, 128, kernel_size=2) # 2x2 -> 1x1
|
||||
|
||||
# Spatial Attention Module (轻量但有效)
|
||||
self.spatial_attn = nn.Sequential(
|
||||
nn.Conv2d(128, 1, kernel_size=1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# Enhanced regression head
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(128, 256), # concat max + avg
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(256, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(128, 4),
|
||||
nn.Sigmoid() # output in [0,1]
|
||||
)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
# Feature extraction
|
||||
c1 = F.relu(self.conv1(x)) # [B, 32, 6, 6]
|
||||
c2 = F.relu(self.conv2(c1)) # [B, 64, 6, 6]
|
||||
c3 = F.relu(self.conv3(c2)) # [B,128, 6, 6]
|
||||
|
||||
# Spatial attention: highlight photon peaks
|
||||
attn = self.spatial_attn(c3) # [B, 1, 6, 6]
|
||||
c3 = c3 * attn # reweight features
|
||||
|
||||
# Regression
|
||||
coords = self.fc(c3.flatten(1)) * 6.0 # scale to [0,6)
|
||||
return coords # [B, 4]
|
||||
Reference in New Issue
Block a user