Files
Jungfraujoch/tests/FPGASpotFindingUnitTest.cpp
2025-03-02 13:15:28 +01:00

172 lines
5.1 KiB
C++

// SPDX-FileCopyrightText: 2024 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include <iostream>
#include <catch2/catch_all.hpp>
#include "../fpga/hls_simulation/hls_cores.h"
// Smoke test of the complete spot_finder core on one synthetic module.
// It sends the following through the core, then checks only the number of words
// that appear on each output stream (pixel values are not inspected here):
//   - a leading packet (user = 0);
//   - one 768-bit data packet per 32 pixels, each paired with a 32-bit mask word;
//   - a trailing packet (user = 1).
TEST_CASE("FPGA_spot_finder_core","[FPGA][SpotFinder]") {
    // Each 768-bit packet carries 32 pixels of 24 bits (see ap_int<24> tmp[32] below).
    // The original expression RAW_MODULE_SIZE * sizeof(uint16_t) / 64 has the same
    // value, but only by coincidence.
    const size_t packets_per_module = RAW_MODULE_SIZE / 32;

    STREAM_768 input;
    STREAM_768 output;
    hls::stream<ap_axiu<32,1,1,1>> strong_pixel;
    hls::stream<ap_axiu<32,1,1,1>> mask_in;
    hls::stream<ap_axiu<32,1,1,1>> mask_out;

    ap_int<32> in_photon_count_threshold = 8;
    float_uint32 in_strong_pixel_threshold;
    in_strong_pixel_threshold.f = 4.0;

    // Every pixel holds its column index, except the last column of each line,
    // which is set to INT24_MIN.
    // NOTE(review): presumably the "invalid pixel" sentinel — confirm against spot_finder.
    std::vector<int32_t> input_frame(RAW_MODULE_SIZE);
    for (int i = 0; i < RAW_MODULE_SIZE; i++) {
        if (i % RAW_MODULE_COLS == RAW_MODULE_COLS - 1)
            input_frame[i] = INT24_MIN;
        else
            input_frame[i] = i % RAW_MODULE_COLS;
    }

    input << packet_768_t{.user = 0};
    for (size_t i = 0; i < packets_per_module; i++) {
        ap_int<24> tmp[32];
        for (int j = 0; j < 32; j++)
            tmp[j] = input_frame[i * 32 + j];
        input << packet_768_t{.data = pack32(tmp), .user = 0};
        // UINT32_MAX: all 32 mask bits set for this packet.
        mask_in << ap_axiu<32,1,1,1>{.data = UINT32_MAX, .user = 0};
    }
    mask_in << ap_axiu<32,1,1,1>{.data = 0, .user = 1};
    input << packet_768_t{.user = 1};

    // The core takes the threshold as its raw 32-bit float bit pattern.
    ap_uint<32> tmp_strong_pixel_threshold = in_strong_pixel_threshold.u;
    spot_finder(input,
                mask_in,
                output,
                mask_out,
                strong_pixel,
                in_photon_count_threshold,
                tmp_strong_pixel_threshold);

    REQUIRE(input.empty());
    // The leading and trailing packets are expected on the output in addition to
    // the data packets (hence + 2).
    REQUIRE(output.size() == packets_per_module + 2);
    // NOTE(review): the meaning of the extra 16 + 1 words is defined by spot_finder
    // and cannot be derived from this file.
    REQUIRE(strong_pixel.size() == packets_per_module + 16 + 1);
    REQUIRE(mask_out.size() == packets_per_module + 1);
}
// Forward declaration that lets the FPGA_spot_finder_prepare test below call this
// stage directly.
// It consumes spot_finder_packet words from data_in and emits:
//   - forwarded packets on data_out;
//   - per-packet statistics on sum_out, sum2_out and valid_out.
// NOTE(review): presumably not exported through hls_cores.h — confirm before
// moving it into a header.
void spot_finder_prepare(
        hls::stream<spot_finder_packet> &data_in,
        hls::stream<spot_finder_packet> &data_out,
        hls::stream<ap_int<SUM_BITWIDTH>> &sum_out,
        hls::stream<ap_int<SUM2_BITWIDTH>> &sum2_out,
        hls::stream<ap_int<VALID_BITWIDTH>> &valid_out);
// Checks spot_finder_prepare against a software reference model.
// For every 32-pixel packet the reference computes three statistics over a vertical
// window of lines [line - 15, line + 15], clipped to the module, covering the same
// 32 columns:
//   - the sum of masked pixel values;
//   - the sum of their squares;
//   - the number of pixels whose mask is set.
// Tagged like FPGA_spot_finder_core. The previous un-bracketed "FPGA" string
// registered no Catch2 tag, so [FPGA] filters skipped this test.
TEST_CASE("FPGA_spot_finder_prepare", "[FPGA][SpotFinder]") {
    // Module geometry as seen by the core: 32 pixels per packet.
    const size_t packets_per_module = RAW_MODULE_SIZE / 32;
    const int packets_per_line = RAW_MODULE_COLS / 32;
    const int module_lines = RAW_MODULE_SIZE / RAW_MODULE_COLS;
    // Half-height of the reference model's vertical window.
    const int half_window = 15;

    hls::stream<spot_finder_packet> data_in;
    hls::stream<spot_finder_packet> data_out;
    hls::stream<ap_int<SUM_BITWIDTH>> sum_out;
    hls::stream<ap_int<SUM2_BITWIDTH>> sum2_out;
    hls::stream<ap_int<VALID_BITWIDTH>> valid_out;

    // Empty packet, just forward.
    data_in << spot_finder_packet{
            .data = 348,
            .user = 0
    };

    // Each pixel's value is its linear index, and every pixel's mask is set.
    // NOTE(review): with an all-ones mask the "mask_val == 0" branch of the reference
    // model below is never exercised — consider masking some pixels.
    std::vector<int32_t> input_frame(RAW_MODULE_SIZE);
    std::vector<uint8_t> mask(RAW_MODULE_SIZE);
    for (int i = 0; i < RAW_MODULE_SIZE; i++) {
        input_frame[i] = i;
        mask[i] = 1;
    }

    for (size_t i = 0; i < packets_per_module; i++) {
        ap_int<24> tmp[32];
        ap_uint<32> mask_tmp;
        for (int j = 0; j < 32; j++) {
            tmp[j] = input_frame[i * 32 + j];
            mask_tmp[j] = (mask[i * 32 + j] != 0);
        }
        data_in << spot_finder_packet{
                .data = pack32(tmp),
                .mask = mask_tmp,
                .user = 0,
        };
    }
    // Trailing packet with user = 1 (presumably the end-of-frame marker).
    data_in << spot_finder_packet{.user = 1};

    spot_finder_prepare(data_in,
                        data_out,
                        sum_out,
                        sum2_out,
                        valid_out);

    REQUIRE(data_in.empty());
    // The leading and trailing packets are expected on data_out in addition to
    // the data packets (hence + 2).
    REQUIRE(data_out.size() == packets_per_module + 2);
    REQUIRE(sum_out.size() == packets_per_module);
    REQUIRE(sum2_out.size() == packets_per_module);
    REQUIRE(valid_out.size() == packets_per_module);

    // Heap storage: three arrays of this size previously lived on the stack (~384 KiB).
    std::vector<int64_t> sum(packets_per_module);
    std::vector<int64_t> sum2(packets_per_module);
    std::vector<int64_t> valid(packets_per_module);
    for (size_t i = 0; i < packets_per_module; i++) {
        sum[i] = sum_out.read().to_int64();
        sum2[i] = sum2_out.read().to_int64();
        valid[i] = valid_out.read().to_int64();
    }

    REQUIRE(sum_out.empty());
    REQUIRE(sum2_out.empty());
    REQUIRE(valid_out.empty());

    int sum_err = 0;
    int sum2_err = 0;
    int valid_err = 0;
    for (size_t i = 0; i < packets_per_module; i++) {
        int64_t sum_ref = 0;
        int64_t sum2_ref = 0;
        int64_t valid_ref = 0;

        const int packet = static_cast<int>(i);
        // col is the packet's position within its line (which group of 32 pixels),
        // not a pixel column.
        const int col = packet % packets_per_line;
        const int line = packet / packets_per_line;

        for (int l = line - half_window; l <= line + half_window; l++) {
            if (l < 0 || l >= module_lines)
                continue; // window clipped at the module edges
            for (int c = 0; c < 32; c++) {
                int64_t val = input_frame[(col * 32 + c) + l * RAW_MODULE_COLS];
                auto mask_val = mask[(col * 32 + c) + l * RAW_MODULE_COLS];
                valid_ref += mask_val;
                if (mask_val == 0)
                    val = 0;
                sum_ref += val;
                sum2_ref += val * val;
            }
        }

        if (valid_ref != valid[i])
            valid_err++;
        if (sum2_ref != sum2[i])
            sum2_err++;
        if (sum_ref != sum[i]) {
            std::cout << i << " " << sum_ref << " " << sum[i] << " " << valid[i] << " " << valid_ref << '\n';
            sum_err++;
        }
    }

    CHECK(sum_err == 0);
    CHECK(sum2_err == 0);
    CHECK(valid_err == 0);
}