Files
Jungfraujoch/image_analysis/ImageAnalysisCPU.cpp
2025-06-10 18:14:04 +02:00

187 lines
7.0 KiB
C++

// SPDX-FileCopyrightText: 2024 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include "ImageAnalysisCPU.h"
#include "StrongPixelSet.h"
#include "../compression/JFJochDecompress.h"
#include "indexing/IndexerFactory.h"
// Build a CPU image analyser for one experiment.
//
// Caches the detector geometry (pixel count, row width) and the saturation limit from the
// experiment, and loads the ROI definition via UpdateROI(). It converts `in_mask` into a
// one-byte-per-pixel flag array (non-zero mask entry -> 1) and creates the indexer with
// CreateIndexer(). `nquads` is later forwarded to the neural-net inference client.
ImageAnalysisCPU::ImageAnalysisCPU(const DiffractionExperiment &in_experiment,
                                   const AzimuthalIntegration &in_integration,
                                   const PixelMask &in_mask)
        : experiment(in_experiment),
          integration(in_integration),
          npixels(experiment.GetPixelsNum()),
          xpixels(experiment.GetXPixelsNum()),
          mask_1byte(npixels, 0),
          spotFinder(in_integration),
          saturation_limit(experiment.GetSaturationLimit()),
          integrate(in_experiment),
          roi_count(0) {
    nquads = 2;
    UpdateROI();

    // Fetch the mask once. Calling GetMask() inside the loop would, if it returns by value,
    // copy the entire mask for every pixel. Binding to const& is correct for both a
    // by-reference and a by-value return (lifetime extension).
    const auto &mask = in_mask.GetMask();
    for (int64_t i = 0; i < npixels; i++)
        mask_1byte[i] = (mask.at(i) != 0);   // at(): throws if the mask is shorter than the detector

    indexer = CreateIndexer(experiment);
}
// Refresh the cached ROI state from the experiment's current ROI definition.
// It sets the number of ROIs, the ROI-name -> index map used when reporting results, and the
// per-pixel ROI map (checked bit-wise per ROI during analysis).
void ImageAnalysisCPU::UpdateROI() {
    const auto &roi_definition = experiment.ROI();
    roi_count = roi_definition.size();
    roi_names = roi_definition.GetROINameMap();
    roi_map = experiment.ExportROIMap();
}
// Check the image geometry, then hand off to the typed Analyze<T> that matches the
// image's pixel mode.
//
// `image` is passed to output.image.GetUncompressedPtr(). Presumably it serves as the
// buffer for decompressed data (TODO confirm against CompressedImage).
// Signed modes pass the type minimum as the error marker and the type maximum as the
// saturation marker. Unsigned modes pass the type maximum for both; Analyze<T> only checks
// the error marker when T is signed.
//
// Throws JFJochException(InputParameterInvalid) when the image width or total pixel count
// differs from the detector, or when the mode is not one of the integer modes.
void ImageAnalysisCPU::Analyze(DataMessage &output, std::vector<uint8_t> &image, AzimuthalIntegrationProfile &profile, const SpotFindingSettings &spot_finding_settings) {
    const auto width = output.image.GetWidth();
    const bool geometry_ok = (width == xpixels)
                             && (width * output.image.GetHeight() == npixels);
    if (!geometry_ok)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "Mismatch in pixel size");

    const uint8_t *raw_pixels = output.image.GetUncompressedPtr(image);
    const auto &settings = spot_finding_settings;

    switch (output.image.GetMode()) {
        case CompressedImageMode::Uint8:
            Analyze<uint8_t>(output, raw_pixels, UINT8_MAX, UINT8_MAX, profile, settings);
            return;
        case CompressedImageMode::Uint16:
            Analyze<uint16_t>(output, raw_pixels, UINT16_MAX, UINT16_MAX, profile, settings);
            return;
        case CompressedImageMode::Uint32:
            Analyze<uint32_t>(output, raw_pixels, UINT32_MAX, UINT32_MAX, profile, settings);
            return;
        case CompressedImageMode::Int8:
            Analyze<int8_t>(output, raw_pixels, INT8_MIN, INT8_MAX, profile, settings);
            return;
        case CompressedImageMode::Int16:
            Analyze<int16_t>(output, raw_pixels, INT16_MIN, INT16_MAX, profile, settings);
            return;
        case CompressedImageMode::Int32:
            Analyze<int32_t>(output, raw_pixels, INT32_MIN, INT32_MAX, profile, settings);
            return;
        default:
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "RGB/float mode not supported");
    }
}
// Analyse one uncompressed image whose pixels have type T.
//
// Every pixel is classified, in this order, as:
//   masked    (mask_1byte != 0),
//   saturated (value >= sat_pixel_val, after clamping sat_pixel_val to saturation_limit),
//   error     (value == err_pixel_val; checked only for signed T),
//   viable    (everything else).
// Rejected pixels are written as INT32_MIN into a working copy of the image (used by quick
// integration). Viable pixels update the min/max statistics and per-ROI accumulators, and
// feed the azimuthal-integration bins with the per-pixel correction applied.
//
// After the pixel pass:
//   - spot finding runs on the raw image (if enabled) and its output is filtered by count;
//   - indexing and quick integration run if enabled;
//   - resolution estimation runs if an inference client is attached.
// All results are written into `output`, and `profile` is cleared and refilled.
//
// Parameters:
//   output                - receives spots, reflections, b-factor, pixel statistics,
//                           azimuthal profile, background estimate, resolution estimate and ROIs
//   in_image              - raw pixel buffer, read as npixels values of type T
//   err_pixel_val         - marker for error pixels (ignored for unsigned T)
//   sat_pixel_val         - threshold at/above which a pixel counts as saturated
//   profile               - azimuthal integration profile, cleared and refilled here
//   spot_finding_settings - switches for spot finding, indexing, quick integration and
//                           resolution estimation
template <class T>
void ImageAnalysisCPU::Analyze(DataMessage &output,
                               const uint8_t *in_image,
                               T err_pixel_val,
                               T sat_pixel_val,
                               AzimuthalIntegrationProfile &profile,
                               const SpotFindingSettings &spot_finding_settings) {
    // NOTE(review): casts away const; kept because spotFinder.Run / Inference may take T*.
    auto image = (T *) in_image;

    std::vector<ROIMessage> roi(roi_count);
    size_t err_pixels = 0;
    size_t masked_pixels = 0;   // NOTE(review): counted but not currently reported in `output`
    size_t sat_pixels = 0;
    int64_t min_value = INT64_MAX;
    int64_t max_value = INT64_MIN;

    if (sat_pixel_val > saturation_limit)
        sat_pixel_val = static_cast<T>(saturation_limit);

    const auto &pixel_to_bin = integration.GetPixelToBin();
    const auto &corrections = integration.Corrections();
    const auto nbins = integration.GetBinNumber();

    // Working copy of the image with rejected pixels set to INT32_MIN (consumed by quick integration).
    std::vector<int32_t> updated_image(npixels);
    std::vector<float> sum(nbins);
    std::vector<uint32_t> count(nbins);

    for (int64_t i = 0; i < npixels; i++) {
        if (mask_1byte[i] != 0) {
            updated_image[i] = INT32_MIN;
            ++masked_pixels;
        } else if (image[i] >= sat_pixel_val) {
            updated_image[i] = INT32_MIN;
            ++sat_pixels;
        } else if (std::is_signed<T>::value && (image[i] == err_pixel_val)) {// Error pixels are possible only for signed types
            updated_image[i] = INT32_MIN;
            ++err_pixels;
        } else {
            updated_image[i] = static_cast<int32_t>(image[i]);
            if (image[i] > max_value)
                max_value = image[i];
            if (image[i] < min_value)
                min_value = image[i];

            if (roi_count > 0 && (roi_map[i] != 0)) {
                const int64_t x = i % xpixels;
                const int64_t y = i / xpixels;
                // size_t counter + 64-bit mask: an int8_t counter with `1 << r` is UB for r >= 31
                // and loops forever if roi_count > 127.
                for (size_t r = 0; r < static_cast<size_t>(roi_count); r++) {
                    if ((roi_map[i] & (static_cast<uint64_t>(1) << r)) != 0) {
                        roi[r].sum += image[i];
                        // Widen before squaring: uint16_t promotes to int (65535^2 > INT_MAX, UB),
                        // and 32-bit types overflow directly.
                        roi[r].sum_square += static_cast<int64_t>(image[i]) * image[i];
                        roi[r].pixels += 1;
                        if (image[i] > roi[r].max_count)
                            roi[r].max_count = image[i];
                        roi[r].x_weighted += x * image[i];
                        roi[r].y_weighted += y * image[i];
                    }
                }
            }

            const auto bin = pixel_to_bin[i];
            if (bin < nbins) {
                // Corrected intensity is only needed for pixels that land in a valid bin.
                const auto value = image[i] * corrections[i];
                sum[bin] += value;
                count[bin] += 1;
            }
        }
    }

    profile.Clear(integration);
    profile.Add(sum, count);

    std::vector<DiffractionSpot> spots;
    if (spot_finding_settings.enable)
        spots = spotFinder.Run(image, spot_finding_settings);

    std::vector<DiffractionSpot> spots_out;
    FilterSpotsByCount(max_spot_count, spots, spots_out);
    for (const auto &spot: spots_out)
        output.spots.push_back(spot);

    if (indexer && spot_finding_settings.indexing) {
        auto latt = indexer->Run(output, spots_out);
        if (latt && spot_finding_settings.quick_integration) {
            auto res = integrate.Integrate(
                    CompressedImage(updated_image, experiment.GetXPixelsNum(),
                                    experiment.GetYPixelsNum()), latt.value(), spot_finding_settings.quick_integration_d_min_A);
            output.reflections = res.reflections;
            output.b_factor = res.b_factor;
        }
    }

    // NOTE(review): if no pixel is viable, these keep the INT64_MAX / INT64_MIN sentinels.
    output.max_viable_pixel_value = max_value;
    output.min_viable_pixel_value = min_value;
    output.error_pixel_count = err_pixels;
    output.saturated_pixel_count = sat_pixels;
    output.az_int_profile = profile.GetResult();
    output.bkg_estimate = profile.GetBkgEstimate(integration.Settings());

    if ((inference_client != nullptr) && spot_finding_settings.resolution_estimate)
        output.resolution_estimate = inference_client->Inference(experiment, image, nquads);

    // roi_names maps each ROI name to its index in `roi`.
    for (const auto &[key, val]: roi_names)
        output.roi[key] = roi[val];
}
// Attach (or, with nullptr, detach) the client used for resolution estimation.
// Only the raw pointer is stored; presumably the caller keeps ownership and must keep the
// client alive while this analyser uses it (TODO confirm).
// Returns *this so calls can be chained.
ImageAnalysisCPU &ImageAnalysisCPU::NeuralNetInference(NeuralNetInferenceClient *client) {
    this->inference_client = client;
    return *this;
}