Files
glocalize/SkullRemovalService.cpp
2026-01-07 14:22:54 +01:00

165 lines
6.6 KiB
C++

#include "SkullRemovalService.h"

#include <atomic>
#include <functional>
#include <stdexcept>
#include <string>

#include <vtkImageData.h>
#include <vtkImageExport.h>
#include <vtkImageImport.h>
#include <vtkSmartPointer.h>
#include <vtkType.h>

#include <itkBinaryDilateImageFilter.h>
#include <itkBinaryErodeImageFilter.h>
#include <itkBinaryThresholdImageFilter.h>
#include <itkCastImageFilter.h>
#include <itkConnectedThresholdImageFilter.h>
#include <itkCurvatureFlowImageFilter.h>
#include <itkEventObject.h>
#include <itkFlatStructuringElement.h>
#include <itkImage.h>
#include <itkMacro.h>
#include <itkMaskNegatedImageFilter.h>
#include <itkProcessObject.h>
#include <itkVTKImageExport.h>
#include <itkVTKImageImport.h>

#include "connectITKVTK.h"
namespace
{
// ---- Dimensionality and pixel types ----------------------------------------
// CT volumes handed over from VTK are, by this codebase's convention, signed
// 16-bit. They are imported with exactly that scalar type, because
// itk::VTKImageImport rejects a mismatching scalar type at runtime. They are then
// converted to float for the processing steps that benefit from it.
constexpr unsigned int Dimension{ 3 };
using InputPixelType = short;
using InternalPixelType = float;
using MaskPixelType = unsigned char; // pixel type of the binary region mask

// ---- Image types -----------------------------------------------------------
using InputImageType = itk::Image<InputPixelType, Dimension>;
using InternalImageType = itk::Image<InternalPixelType, Dimension>;
using MaskImageType = itk::Image<MaskPixelType, Dimension>;

// ---- Pipeline stages -------------------------------------------------------
// VTK -> ITK bridge for the short input volume.
using ImportFilterType = itk::VTKImageImport<InputImageType>;

// Conversions into and out of the float processing chain.
using CastToFloatFilterType = itk::CastImageFilter<InputImageType, InternalImageType>;
using CastBackToInputFilterType = itk::CastImageFilter<InternalImageType, InputImageType>;

// Smoothing applied to the float image before region growing.
using CurvatureFlowImageFilterType = itk::CurvatureFlowImageFilter<InternalImageType, InternalImageType>;

// Seeded region growing: float image in, binary mask out.
using ConnectedThresholdImageFilterType = itk::ConnectedThresholdImageFilter<InternalImageType, MaskImageType>;

// Binary morphology on the mask with a flat structuring element.
using StructuringElementType = itk::FlatStructuringElement<Dimension>;
using DilateFilterType = itk::BinaryDilateImageFilter<MaskImageType, MaskImageType, StructuringElementType>;
using ErodeFilterType = itk::BinaryErodeImageFilter<MaskImageType, MaskImageType, StructuringElementType>;

// Applies the negated mask to the float image.
using MaskNegatedFilterType = itk::MaskNegatedImageFilter<InternalImageType, MaskImageType, InternalImageType>;
} // namespace
vtkImageData* SkullRemovalService::run(vtkImageData* inputVolume,
double thr_low,
double thr_up,
std::atomic_bool* abortFlag,
std::function<void(const std::string&, double)> progressCb)
{
if (!inputVolume)
throw std::runtime_error("SkullRemovalService::run: inputVolume is null");
auto isAborted = [&]() -> bool { return abortFlag && abortFlag->load(); };
if (progressCb) progressCb("Skull masking: preparing pipeline", 0.0);
if (isAborted()) return nullptr;
// VTK -> ITK
vtkSmartPointer<vtkImageExport> vtkExporter = vtkSmartPointer<vtkImageExport>::New();
vtkExporter->SetInputData(inputVolume);
auto itkImporter = ImportFilterType::New();
ConnectVTKToITK(vtkExporter.GetPointer(), itkImporter);
// Pipeline (mirrors old gSkullRemoval::runFilter in a pure form)
auto castToFloat = CastToFloatFilterType::New();
auto smoothing = CurvatureFlowImageFilterType::New();
auto connectedThreshold = ConnectedThresholdImageFilterType::New();
auto binaryDilateFilter = DilateFilterType::New();
auto binaryErodeFilter = ErodeFilterType::New();
auto maskNegatedFilter = MaskNegatedFilterType::New();
smoothing->SetNumberOfIterations(10);
smoothing->SetTimeStep(0.125);
castToFloat->SetInput(itkImporter->GetOutput());
smoothing->SetInput(castToFloat->GetOutput());
connectedThreshold->SetInput(smoothing->GetOutput());
connectedThreshold->SetLower(static_cast<InternalPixelType>(thr_low));
connectedThreshold->SetUpper(static_cast<InternalPixelType>(thr_up));
connectedThreshold->SetReplaceValue(255);
// Seed at volume center
// connectedThreshold->UpdateOutputInformation();
// Make sure upstream information is propagated
connectedThreshold->UpdateOutputInformation();
// Use the image that DEFINITELY has region info propagated
auto in = smoothing->GetOutput(); // or connectedThreshold->GetInput()
const auto region = in->GetLargestPossibleRegion();
const auto size = region.GetSize();
const auto start = region.GetIndex();
InternalImageType::IndexType seed;
seed[0] = start[0] + static_cast<long>(size[0] / 2);
seed[1] = start[1] + static_cast<long>(size[1] / 2);
seed[2] = start[2] + static_cast<long>(size[2] / 2);
connectedThreshold->SetSeed(seed);
// InternalImageType::Pointer in = castToFloat->GetOutput();
// const auto region = in->GetBufferedRegion();
// const auto size = region.GetSize();
// const auto start = region.GetIndex();
// InternalImageType::IndexType seed;
// seed[0] = start[0] + static_cast<long>(size[0] / 2);
// seed[1] = start[1] + static_cast<long>(size[1] / 2);
// seed[2] = start[2] + static_cast<long>(size[2] / 2);
// connectedThreshold->SetSeed(seed);
// ITK 5.x expects a RadiusType (itk::Size<Dimension>) for FlatStructuringElement::Ball
StructuringElementType::RadiusType radius;
radius.Fill(10);
const auto kernel = StructuringElementType::Ball(radius);
binaryDilateFilter->SetKernel(kernel);
binaryErodeFilter->SetKernel(kernel);
binaryDilateFilter->SetDilateValue(255);
binaryErodeFilter->SetErodeValue(255);
binaryDilateFilter->SetInput(connectedThreshold->GetOutput());
binaryErodeFilter->SetInput(binaryDilateFilter->GetOutput());
maskNegatedFilter->SetInput1(castToFloat->GetOutput());
maskNegatedFilter->SetInput2(binaryErodeFilter->GetOutput());
if (progressCb) progressCb("Skull masking: running", 0.5);
if (isAborted()) return nullptr;
try {
maskNegatedFilter->Update();
} catch (const itk::ExceptionObject& ex) {
if (isAborted()) return nullptr;
throw std::runtime_error(std::string("SkullRemovalService: ITK exception: ") + ex.GetDescription());
}
if (isAborted()) return nullptr;
// ITK -> VTK (cast back to the original scalar type expected by the rest of the app)
auto castBackToInput = CastBackToInputFilterType::New();
castBackToInput->SetInput(maskNegatedFilter->GetOutput());
vtkSmartPointer<vtkImageImport> vtkImporter = vtkSmartPointer<vtkImageImport>::New();
auto itkExporter = itk::VTKImageExport<InputImageType>::New();
itkExporter->SetInput(castBackToInput->GetOutput());
ConnectITKToVTK(itkExporter, vtkImporter.GetPointer());
vtkImporter->Update();
vtkImageData* out = vtkImageData::New();
out->DeepCopy(vtkImporter->GetOutput());
if (progressCb) progressCb("Skull masking: done", 1.0);
return out;
}