Files
Jungfraujoch/fpga/hls/spot_finder.cpp

271 lines
9.8 KiB
C++

// Copyright (2019-2023) Paul Scherrer Institute
#include "hls_jfjoch.h"
// Widths of the accumulators used for the per-group pixel statistics
// (produced in spot_finder_prepare, consumed in spot_finder_snr_threshold).
// NOTE(review): the breakdowns below read as <operand bits> + <width of the
// valid-pixel counter, VALID_BITWIDTH = 12> + 1 spare bit -- confirm intent.
#define SUM_BITWIDTH 29 // sum of 16-bit pixel values: 16 + 12 + 1
#define SUM2_BITWIDTH 45 // sum of squared pixel values (32-bit squares): 32 + 12 + 1
#define VALID_BITWIDTH 12 // number of valid pixels (signed, so at most 2047)
#ifdef JFJOCH_HLS_NOSYNTH
#include <thread>
#endif
// Population count of a 32-bit word: returns how many of its bits are set.
// Used to count the strong pixels flagged in one packet's 32-bit mask.
// The argument is only read, never modified.
ap_uint<32> count_pixels(ap_uint<32> &in) {
#pragma HLS INLINE
    ap_uint<32> set_bits = 0;
    for (int bit = 31; bit >= 0; bit--) {
        if (in[bit])
            set_bits++;
    }
    return set_bits;
}
// Element type of the internal streams connecting the spot-finder stages.
// It keeps only the 512-bit payload and the user/last side-band bits of a
// packet_512_t; no other packet_512_t fields are carried over.
// Field order matters: instances are built with designated initializers in
// the order .data, .user, .last.
struct spot_finder_packet {
    ap_uint<512> data; // 32 pixels of 16 bits each (split with unpack32)
    ap_uint<1> user;   // stages stop forwarding after the first packet with user set
    ap_uint<1> last;   // copied through unchanged
};
void spot_finder_in_stream(STREAM_512 &data_in,
hls::stream<spot_finder_packet> &data_out) {
packet_512_t packet_in;
data_in >> packet_in;
data_out << spot_finder_packet{.data = packet_in.data, .user = packet_in.user, .last = packet_in.last};
data_in >> packet_in;
while (!packet_in.user) {
#pragma HLS PIPELINE II=1
data_out << spot_finder_packet{.data = packet_in.data, .user = packet_in.user, .last = packet_in.last};
data_in >> packet_in;
}
data_out << spot_finder_packet{.data = packet_in.data, .user = packet_in.user, .last = packet_in.last};
}
// Validity mask for the 32 pixels of a packet: bit i is 1 unless val[i] is
// one of the reserved values INT16_MIN / INT16_MAX. Pixels with a 0 bit are
// left out of the sum / sum2 / valid statistics.
ap_uint<32> calc_mask(ap_int<16> val[32]) {
#pragma HLS PIPELINE II=1
    ap_uint<32> valid_bits = 0;
    for (int px = 0; px < 32; px++) {
        const bool is_reserved = (val[px] == INT16_MIN) || (val[px] == INT16_MAX);
        valid_bits[px] = !is_reserved;
    }
    return valid_bits;
}
// Sum of the pixel values of one packet, counting only pixels whose bit is
// set in mask (masked-out pixels contribute zero).
ap_int<SUM_BITWIDTH> calc_sum(ap_int<16> val[32], ap_uint<32> mask) {
#pragma HLS PIPELINE II=1
    ap_int<SUM_BITWIDTH> total = 0;
    for (int px = 0; px < 32; px++) {
        const ap_int<16> contribution = mask[px] ? val[px] : ap_int<16>(0);
        total += contribution;
    }
    return total;
}
// Sum of the squared pixel values of one packet, counting only pixels whose
// bit is set in mask. A 16x16-bit product is 32 bits wide, so each square is
// held exactly before accumulation.
ap_int<SUM2_BITWIDTH> calc_sum2(ap_int<16> val[32], ap_uint<32> mask) {
#pragma HLS PIPELINE II=1
    ap_int<SUM2_BITWIDTH> total = 0;
    for (int px = 0; px < 32; px++) {
        const ap_int<32> square = val[px] * val[px];
        if (mask[px])
            total += square;
    }
    return total;
}
// Number of valid pixels in one packet, i.e. the number of set bits in mask.
ap_int<VALID_BITWIDTH> calc_valid(ap_uint<32> mask) {
#pragma HLS PIPELINE II=1
    ap_int<VALID_BITWIDTH> count = 0;
    for (int bit = 0; bit < 32; bit++)
        count += mask[bit] ? 1 : 0;
    return count;
}
// Second dataflow stage.
//
// Forwards every packet from data_in to data_out unchanged and, alongside,
// accumulates the pixel statistics needed by the SNR threshold:
//   sum   - sum of valid pixel values
//   sum2  - sum of squared valid pixel values
//   valid - number of valid pixels (calc_mask: value not INT16_MIN / INT16_MAX)
//
// Within a module, packet index i is split into a column index (i % 32) and a
// row index (i / 32). Statistics are kept per column index and accumulated
// over groups of 32 consecutive rows:
//   - on row (i / 32) % 32 == 0 they are restarted,
//   - on rows 1..30 they are accumulated,
//   - on row 31 the final totals are pushed to sum_out / sum2_out / valid_out,
//     one entry per column index, in column order.
// NOTE(review): if a module row is 32 packets (1024 pixels) wide, each entry
// describes a 32x32-pixel box -- confirm against RAW_MODULE_SIZE.
//
// Framing: the first packet is forwarded as-is. Whole modules are then processed
// until a packet with the user bit set is seen, and that packet is forwarded last.
// The user bit is only checked between modules, so this presumably relies on it
// arriving on a module boundary.
void spot_finder_prepare(hls::stream<spot_finder_packet> &data_in,
                         hls::stream<spot_finder_packet> &data_out,
                         hls::stream<ap_int<SUM_BITWIDTH>> &sum_out,
                         hls::stream<ap_int<SUM2_BITWIDTH>> &sum2_out,
                         hls::stream<ap_int<VALID_BITWIDTH>> &valid_out) {
    spot_finder_packet packet;
    // First packet is passed through without inspection.
    data_in >> packet;
    data_out << packet;
    // Per-column-index accumulators.
    ap_int<SUM_BITWIDTH> sum[32];
    ap_int<SUM2_BITWIDTH> sum2[32];
    ap_int<VALID_BITWIDTH> valid[32];
    for (int col = 0; col < 32; col++) {
        sum[col] = 0;
        sum2[col] = 0;
        valid[col] = 0;
    }
    data_in >> packet;
    while (!packet.user) {
        // One iteration per 512-bit (64-byte) packet of a module of 16-bit pixels.
        for (int i = 0; i < RAW_MODULE_SIZE * sizeof(uint16_t) / 64; i++) {
#pragma HLS PIPELINE II=1
            data_out << packet;
            ap_int<16> val[32];
            unpack32(packet.data, val);
            ap_uint<32> mask = calc_mask(val);
            if ((i / 32) % 32 == 0) {
                // First row of a group: restart accumulation with this packet.
                sum[i % 32] = calc_sum(val, mask);
                sum2[i % 32] = calc_sum2(val, mask);
                valid[i % 32] = calc_valid(mask);
            } else if ((i / 32) % 32 == 31) {
                // Last row of a group: emit the totals directly (no need to store them).
                sum_out << sum[i % 32] + calc_sum(val, mask);
                sum2_out << sum2[i % 32] + calc_sum2(val, mask);
                valid_out << valid[i % 32] + calc_valid(mask);
            } else {
                // Intermediate rows: accumulate.
                sum[i % 32] += calc_sum(val, mask);
                sum2[i % 32] += calc_sum2(val, mask);
                valid[i % 32] += calc_valid(mask);
            }
            data_in >> packet;
        }
    }
    // Terminating packet (user bit set) is forwarded as well.
    data_out << packet;
}
// SNR criterion for the 32 pixels of one packet.
//
// Let N = valid_count, S = sum and Q = sum2 be the statistics of the group the
// pixels belong to. A pixel x is strong when x lies above the mean by more than
// (snr_threshold / 4) standard deviations. snr_threshold is in units of 0.25,
// so snr_threshold_2 / 16 is the squared threshold.
//
// Square roots and divisions are avoided by scaling both sides by N and squaring:
//     (x*N - S) > 0   and   (x*N - S)^2 > (N*Q - S^2) * snr_threshold_2 / 16
// Here N*Q - S^2 = N^2 * (population) variance, i.e. N is taken as ~ N-1.
// The right-hand side is rounded to nearest (+8 before /16).
//
// Returns UINT32_MAX (criterion disabled) when snr_threshold_2 == 0. Otherwise
// returns a mask with bit j set for each strong pixel j. The mask is all zero
// unless strictly more than half of the 32*32 pixels are valid.
// NOTE(review): pixels equal to INT16_MAX are excluded from the statistics by
// calc_mask but not here, so they will usually pass -- confirm intended.
ap_uint<32> spot_finder_snr_threshold(ap_int<16> val[32],
                                      ap_uint<16> snr_threshold_2,
                                      ap_int<SUM_BITWIDTH> sum,
                                      ap_int<SUM2_BITWIDTH> sum2,
                                      ap_int<VALID_BITWIDTH> valid_count) {
#pragma HLS PIPELINE II=1
    if (snr_threshold_2 == 0)
        return UINT32_MAX;

    // Too few valid pixels for meaningful statistics: nothing is strong.
    if (valid_count <= 32 * 32 / 2)
        return 0;

    // N^2 * variance, and the same scaled by the squared threshold.
    const ap_int<SUM2_BITWIDTH+12> n2_var = valid_count * sum2 - sum * sum;
    const ap_int<SUM2_BITWIDTH+12+16> n2_var_snr2 = (n2_var * snr_threshold_2 + 8) / (4*4);

    ap_uint<32> strong = 0;
    for (int px = 0; px < 32; px++) {
        // (pixel - mean) * N
        const ap_int<SUM_BITWIDTH+1> dev = val[px] * valid_count - sum;
        strong[px] = (dev > 0) && (dev * dev > n2_var_snr2);
    }
    return strong;
}
// Count (photon/ADU) criterion for the 32 pixels of one packet: bit j is set
// when val[j] >= count_threshold.
// A non-positive threshold disables the criterion and returns UINT32_MAX.
// count_threshold is only read, never modified.
ap_uint<32> spot_finder_count_threshold(ap_int<16> val[32],
                                        ap_int<16> &count_threshold) {
#pragma HLS PIPELINE II=1
    if (count_threshold <= 0)
        return UINT32_MAX;

    ap_uint<32> above = 0;
    for (int px = 0; px < 32; px++)
        above[px] = (val[px] >= count_threshold);
    return above;
}
// Third dataflow stage.
//
// Forwards every packet from data_in to data_out, converted back to packet_512_t.
// For every packet inside a module it also writes a 32-bit mask to
// strong_pixel_out. Bit j of the mask is set when pixel j of the packet is
// "strong", i.e. passes BOTH spot_finder_count_threshold and
// spot_finder_snr_threshold.
//
// Statistics from spot_finder_prepare are read on row (i / 32) % 32 == 0 for
// column index i % 32. They are reused for the following 31 rows of that group.
// This stage can therefore only start a group after prepare has finished it,
// which is presumably why the FIFO feeding data_in is deep.
//
// Thresholds are sampled once at the start of every module. After each module,
// 16 words are appended to strong_pixel_out:
//   - count threshold,
//   - SNR threshold,
//   - number of strong pixels in the module,
//   - 13 zero words.
// When the packet with the user bit set arrives, a word with user = 1 is written
// to strong_pixel_out and that packet is forwarded to data_out.
void spot_finder_apply_threshold(hls::stream<spot_finder_packet> &data_in,
                                 STREAM_512 &data_out,
                                 hls::stream<ap_int<SUM_BITWIDTH>> &sum_in,
                                 hls::stream<ap_int<SUM2_BITWIDTH>> &sum2_in,
                                 hls::stream<ap_int<VALID_BITWIDTH>> &valid_in,
                                 hls::stream<ap_axiu<32,1,1,1>> &strong_pixel_out,
                                 volatile ap_int<16> &in_count_threshold,
                                 volatile ap_uint<8> &in_snr_threshold) {
    spot_finder_packet packet_in;
    // First packet is passed through without inspection.
    data_in >> packet_in;
    data_out << packet_512_t{.data = packet_in.data, .user = packet_in.user, .last = packet_in.last};
    // Statistics of the current group, per column index.
    ap_int<SUM_BITWIDTH> sum[32];
    ap_int<SUM2_BITWIDTH> sum2[32];
    ap_int<VALID_BITWIDTH> valid[32];
    data_in >> packet_in;
    while (!packet_in.user) {
        // Thresholds are latched once per module; changes take effect at the next module.
        ap_int<16> count_threshold = in_count_threshold;
        ap_uint<8> snr_threshold = in_snr_threshold;
        // Squared SNR threshold, in units of 1/16 (snr_threshold is in units of 0.25).
        ap_uint<16> snr_threshold_2 = snr_threshold * snr_threshold;
        ap_uint<32> strong_pixel_count = 0;
        for (int i = 0; i < RAW_MODULE_SIZE * sizeof(uint16_t) / 64; i++) {
#pragma HLS PIPELINE II=1
            // First row of a group: fetch its statistics (blocks until prepare has emitted them).
            if ((i / 32) % 32 == 0) {
                sum_in >> sum[i % 32];
                sum2_in >> sum2[i % 32];
                valid_in >> valid[i % 32];
            }
            data_out << packet_512_t{.data = packet_in.data, .user = packet_in.user, .last = packet_in.last};
            ap_int<16> data_unpacked[32];
            unpack32(packet_in.data, data_unpacked);
            // A pixel is strong only if it passes both criteria.
            ap_uint<32> strong_pixel = spot_finder_count_threshold(data_unpacked, count_threshold) &
                                       spot_finder_snr_threshold(data_unpacked, snr_threshold_2,
                                                                 sum[i % 32], sum2[i % 32], valid[i % 32]);
            strong_pixel_out << ap_axiu<32,1,1,1>{.data = strong_pixel, .user = 0};
            strong_pixel_count += count_pixels(strong_pixel);
            data_in >> packet_in;
        }
        // Save module statistics: 3 informative words padded to 16 words in total.
        strong_pixel_out << ap_axiu<32,1,1,1>{.data = count_threshold, .user = 0};
        strong_pixel_out << ap_axiu<32,1,1,1>{.data = snr_threshold, .user = 0};
        strong_pixel_out << ap_axiu<32,1,1,1>{.data = strong_pixel_count, .user = 0};
        for (int i = 0; i < 13; i++)
            strong_pixel_out << ap_axiu<32,1,1,1>{.data = 0, .user = 0};
    }
    // End-of-stream marker on the strong-pixel output, then forward the terminating packet.
    strong_pixel_out << ap_axiu<32,1,1,1>{.data = 0, .user = 1};
    data_out << packet_512_t{.data = packet_in.data, .user = packet_in.user, .last = packet_in.last};
}
// Top-level spot-finder kernel (HLS DATAFLOW region).
//
// Stages:
//   spot_finder_in_stream       : data_in -> data_0
//                                 (keeps payload + user/last bits)
//   spot_finder_prepare         : data_0  -> data_1, plus per-group sum / sum2 / valid
//   spot_finder_apply_threshold : data_1  -> data_out, plus strong-pixel masks
//                                 and per-module statistics on strong_pixel_out
//
// apply_threshold needs a group's statistics when it reaches the group's first
// row. prepare only emits them after the group's last row, so the packets in
// between must wait in data_1. Its depth of 1080 (BRAM) is presumably sized
// for that lag.
//
// In C simulation (JFJOCH_HLS_NOSYNTH) the three stages run on separate
// std::threads communicating through the hls::streams. The function returns
// once all of them have finished.
void spot_finder(STREAM_512 &data_in,
                 STREAM_512 &data_out,
                 hls::stream<ap_axiu<32,1,1,1>> &strong_pixel_out,
                 volatile ap_int<16> &in_count_threshold,
                 volatile ap_uint<8> &in_snr_threshold) {
#pragma HLS INTERFACE axis port=data_in
#pragma HLS INTERFACE axis port=data_out
#pragma HLS INTERFACE axis port=strong_pixel_out
#pragma HLS INTERFACE ap_none register port=in_count_threshold
#pragma HLS INTERFACE ap_none register port=in_snr_threshold
#pragma HLS DATAFLOW
    hls::stream<spot_finder_packet, 2> data_0;
    hls::stream<spot_finder_packet, 1080> data_1;
#pragma HLS BIND_STORAGE variable=data_1 type=fifo impl=bram
    hls::stream<ap_int<SUM_BITWIDTH>, 24> sum_0;
    hls::stream<ap_int<SUM2_BITWIDTH>, 24> sum2_0;
    hls::stream<ap_int<VALID_BITWIDTH>, 24> valid_0;
#ifndef JFJOCH_HLS_NOSYNTH
    spot_finder_in_stream(data_in, data_0);
    spot_finder_prepare(data_0, data_1, sum_0, sum2_0, valid_0);
    spot_finder_apply_threshold(data_1, data_out, sum_0, sum2_0, valid_0, strong_pixel_out,
                                in_count_threshold, in_snr_threshold);
#else
    // The number of stages is fixed, so plain std::thread objects are enough.
    // This needs only <thread> (included at the top of the file under the same
    // macro) instead of relying on <vector> being pulled in transitively.
    // All threads are joined before the streams they capture by reference go
    // out of scope.
    std::thread in_stream_core([&] { spot_finder_in_stream(data_in, data_0); });
    std::thread prepare_core([&] { spot_finder_prepare(data_0, data_1, sum_0, sum2_0, valid_0); });
    std::thread threshold_core([&] {
        spot_finder_apply_threshold(data_1, data_out, sum_0, sum2_0, valid_0,
                                    strong_pixel_out, in_count_threshold,
                                    in_snr_threshold);
    });
    in_stream_core.join();
    prepare_core.join();
    threshold_core.join();
#endif
}