19#ifndef itkImpactModelConfiguration_h
20#define itkImpactModelConfiguration_h
22#include <torch/script.h>
23#include <torch/torch.h>
31#include "itkStatisticsImageFilter.h"
42 for (
int i = 0; i < vec.size(); ++i)
45 if (i != vec.size() - 1)
68 unsigned int dimension,
69 unsigned int numberOfChannels,
70 std::vector<unsigned int> patchSize,
71 std::vector<float> voxelSize,
72 std::vector<bool> layersMask,
74 bool useMixedPrecision)
81 ,
m_DataType(useMixedPrecision ? torch::kFloat16 : torch::kFloat32)
83 m_Model = std::make_shared<torch::jit::script::Module>(torch::jit::load(
m_ModelPath, torch::Device(torch::kCPU)));
86 m_nArgs =
m_Model->get_method(
"forward").function().getSchema().arguments().size() - 1;
143 friend std::ostream &
160 const torch::ScalarType &
176 const std::vector<int64_t> &
181 const std::vector<float> &
186 const std::vector<bool> &
193 to(torch::Device device)
const
198 template <
class TImage>
200 setup(
typename TImage::ConstPointer image)
202 auto imageStats = itk::StatisticsImageFilter<TImage>::New();
203 imageStats->SetInput(image);
204 imageStats->Update();
206 torch::Tensor imageStatsTensor = torch::tensor({
static_cast<float>(imageStats->GetMinimum()),
207 static_cast<float>(imageStats->GetMaximum()),
208 static_cast<float>(imageStats->GetMean()),
209 static_cast<float>(imageStats->GetSigma()) },
212 const auto & imageDirection = image->GetDirection();
214 constexpr unsigned int imageDimension = TImage::ImageDimension;
215 torch::Tensor imageDirectionTensor = torch::empty({ imageDimension, imageDimension }, torch::kInt16);
217 for (
unsigned int r = 0; r < imageDimension; ++r)
219 for (
unsigned int c = 0; c < imageDimension; ++c)
221 imageDirectionTensor[r][c] =
static_cast<int16_t
>(std::llround(imageDirection(r, c)));
228 std::vector<torch::jit::IValue>
232 std::vector<torch::jit::IValue> args;
234 args.emplace_back(inputPatch);
247 return m_Model->forward(std::move(args)).toList().vec();
250 const std::vector<std::vector<float>> &
255 const std::vector<std::vector<torch::indexing::TensorIndex>> &
274 std::shared_ptr<torch::jit::script::Module>
m_Model;
std::string GetStringFromVector(const std::vector< T > &vec)
torch::Tensor m_imageStatsTensor
const std::vector< int64_t > & GetPatchSize() const
std::vector< torch::jit::IValue > forward(torch::Tensor inputPatch) const
torch::Tensor m_imageDirectionTensor
ImpactModelConfiguration(std::string modelPath, unsigned int dimension, unsigned int numberOfChannels, std::vector< unsigned int > patchSize, std::vector< float > voxelSize, std::vector< bool > layersMask, bool isStatic, bool useMixedPrecision)
friend std::ostream & operator<<(std::ostream &os, const ImpactModelConfiguration &config)
const torch::ScalarType & GetDataType() const
std::vector< int64_t > m_PatchSize
std::vector< float > m_VoxelSize
ImpactModelConfiguration & operator=(ImpactModelConfiguration &&)=default
ImpactModelConfiguration & operator=(const ImpactModelConfiguration &)=delete
const std::vector< bool > & GetLayersMask() const
std::vector< std::vector< torch::indexing::TensorIndex > > m_CentersIndexLayers
void to(torch::Device device) const
std::shared_ptr< torch::jit::script::Module > m_Model
void setup(typename TImage::ConstPointer image)
unsigned int GetDimension() const
torch::ScalarType m_DataType
const std::string & GetModelPath() const
void SetCentersIndexLayers(std::vector< std::vector< torch::indexing::TensorIndex > > ¢ersIndexLayers)
ImpactModelConfiguration(ImpactModelConfiguration &&)=default
std::vector< bool > m_LayersMask
unsigned int GetNumberOfChannels() const
~ImpactModelConfiguration()=default
ImpactModelConfiguration(const ImpactModelConfiguration &)=delete
std::vector< std::vector< float > > m_PatchIndex
const std::vector< std::vector< torch::indexing::TensorIndex > > & GetCentersIndexLayers() const
const std::vector< float > & GetVoxelSize() const
unsigned int m_NumberOfChannels
const std::vector< std::vector< float > > & GetPatchIndex() const
bool operator==(const ImpactModelConfiguration &rhs) const