NeuralNetInferenceClient: Accept PixelMask

This commit is contained in:
2025-06-30 21:29:48 +02:00
parent dba2544c48
commit 92288c60d7
7 changed files with 110 additions and 37 deletions

View File

@@ -65,7 +65,10 @@ void NeuralNetInferenceClient::AddHost(std::string addr) {
template<class T>
std::vector<float> NeuralNetInferenceClient::PrepareInternal(const DiffractionExperiment& experiment, const T* image, Quarter q) {
std::vector<float> NeuralNetInferenceClient::PrepareInternal(const DiffractionExperiment& experiment,
const PixelMask& mask,
const T* image,
Quarter q) {
std::vector<float> ret(512*512);
int64_t pool_factor = GetMaxPoolFactor(experiment);
@@ -102,31 +105,42 @@ std::vector<float> NeuralNetInferenceClient::PrepareInternal(const DiffractionEx
for (int64_t yp = y0; yp < max_yp; yp++) {
for (int64_t xp = x0; xp < max_xp; xp++) {
int64_t pxl = image[yp * xpixel + xp];
if (pxl > INT16_MAX)
if (mask.GetMask().at(yp * xpixel + xp) != 0)
pxl = INT64_MAX;
else if (pxl > INT16_MAX)
pxl = INT16_MAX;
if (val < pxl)
if (pxl > val)
val = pxl;
}
}
float max_pool = floor(sqrt(static_cast<double>(val)));
ret[512 * y + x] = max_pool;
if (val == INT64_MAX)
ret[512 * y + x] = 0;
else
ret[512 * y + x] = floor(sqrt(static_cast<double>(val)));
}
}
return ret;
}
std::vector<float> NeuralNetInferenceClient::Prepare(const DiffractionExperiment& experiment, const int16_t *image, Quarter q) {
return PrepareInternal(experiment, image, q);
std::vector<float> NeuralNetInferenceClient::Prepare(const DiffractionExperiment& experiment,
const PixelMask& mask,
const int16_t *image, Quarter q) {
return PrepareInternal(experiment, mask, image, q);
}
std::vector<float> NeuralNetInferenceClient::Prepare(const DiffractionExperiment& experiment, const int32_t *image, Quarter q) {
return PrepareInternal(experiment, image, q);
std::vector<float> NeuralNetInferenceClient::Prepare(const DiffractionExperiment& experiment,
const PixelMask& mask,
const int32_t *image, Quarter q) {
return PrepareInternal(experiment, mask, image, q);
}
std::vector<float> NeuralNetInferenceClient::Prepare(const DiffractionExperiment& experiment, const int8_t *image, Quarter q) {
return PrepareInternal(experiment, image, q);
std::vector<float> NeuralNetInferenceClient::Prepare(const DiffractionExperiment& experiment,
const PixelMask& mask,
const int8_t *image, Quarter q) {
return PrepareInternal(experiment, mask, image, q);
}
size_t NeuralNetInferenceClient::GetMaxPoolFactor(const DiffractionExperiment& experiment) const {
@@ -206,20 +220,23 @@ std::optional<float> NeuralNetInferenceClient::Run(const std::vector<float> &inp
}
std::optional<float>
NeuralNetInferenceClient::Inference(const DiffractionExperiment &experiment, const void *image, int nquads) {
NeuralNetInferenceClient::Inference(const DiffractionExperiment &experiment,
const PixelMask& mask,
const void *image,
int nquads) {
if (!enable)
return {};
std::optional<float> quad[4];
if (nquads >= 1)
quad[0] = Inference(experiment, image, Quarter::BottomRight);
quad[0] = Inference(experiment, mask, image, Quarter::BottomRight);
if (nquads >= 2)
quad[1] = Inference(experiment, image, Quarter::BottomRight);
quad[1] = Inference(experiment, mask, image, Quarter::BottomRight);
if (nquads >= 3)
quad[2] = Inference(experiment, image, Quarter::BottomRight);
quad[2] = Inference(experiment, mask, image, Quarter::BottomRight);
if (nquads >= 4)
quad[3] = Inference(experiment, image, Quarter::BottomLeft);
quad[3] = Inference(experiment, mask, image, Quarter::BottomLeft);
int count = 0;
float sum = 0.0f;
@@ -234,7 +251,10 @@ NeuralNetInferenceClient::Inference(const DiffractionExperiment &experiment, con
return sum / count;
}
std::optional<float> NeuralNetInferenceClient::Inference(const DiffractionExperiment& experiment, const void *image, Quarter q) {
std::optional<float> NeuralNetInferenceClient::Inference(const DiffractionExperiment& experiment,
const PixelMask& mask,
const void *image,
Quarter q) {
if (!enable)
return {};
@@ -243,21 +263,21 @@ std::optional<float> NeuralNetInferenceClient::Inference(const DiffractionExperi
switch (experiment.GetByteDepthImage()) {
case 1:
if (experiment.IsPixelSigned())
v = PrepareInternal(experiment, (int8_t *) image, q);
v = PrepareInternal(experiment, mask, (int8_t *) image, q);
else
v = PrepareInternal(experiment, (uint8_t *) image, q);
v = PrepareInternal(experiment, mask, (uint8_t *) image, q);
break;
case 2:
if (experiment.IsPixelSigned())
v = PrepareInternal(experiment, (int16_t *) image, q);
v = PrepareInternal(experiment, mask, (int16_t *) image, q);
else
v = PrepareInternal(experiment, (uint16_t *) image, q);
v = PrepareInternal(experiment, mask, (uint16_t *) image, q);
break;
case 4:
if (experiment.IsPixelSigned())
v = PrepareInternal(experiment, (int32_t *) image, q);
v = PrepareInternal(experiment, mask, (int32_t *) image, q);
else
v = PrepareInternal(experiment, (uint32_t *) image, q);
v = PrepareInternal(experiment, mask, (uint32_t *) image, q);
break;
default:
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Bit depth not supported");