ImpactLoss.h
/*=========================================================================
 *
 * Copyright UMC Utrecht and contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *=========================================================================*/

#ifndef _ImpactLoss_h
#define _ImpactLoss_h

#include <torch/torch.h>
#include <cmath>
#include <iostream>
#include "itkTimeProbe.h"

namespace ImpactLoss
{

class Loss
{
private:
  mutable double m_Normalization = 0;

protected:
  double        m_Value;
  torch::Tensor m_Derivative;
  bool          m_Initialized = false;
  int           m_NumberOfParameters = 0;

public:
  Loss(bool isLossNormalized)
  {
    if (!isLossNormalized)
    {
      m_Normalization = 1.0;
    }
  }

  void
  setNumberOfParameters(int numberOfParameters)
  {
    m_NumberOfParameters = numberOfParameters;
  }
  void
  reset()
  {
    m_Initialized = false;
  }

  virtual void
  initialize(torch::Tensor & output)
  {
    // Lazy initialization of internal buffers based on output tensor shape and number of parameters
    if (!m_Initialized)
    {
      m_Value = 0;
      m_Derivative = torch::zeros({ m_NumberOfParameters }, output.options());
      m_Initialized = true;
    }
  }

  virtual void
  updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) = 0;
  virtual void
  updateValueAndDerivativeInStaticMode(torch::Tensor & fixedOutput,
                                       torch::Tensor & movingOutput,
                                       torch::Tensor & jacobian,
                                       torch::Tensor & nonZeroJacobianIndices)
  {
    m_Derivative.index_add_(
      0,
      nonZeroJacobianIndices.flatten(),
      (updateValueAndGetGradientModulator(fixedOutput, movingOutput).unsqueeze(-1) * jacobian).sum(1).flatten());
  }
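
  // Note: in static mode the per-sample gradient modulator (dLoss/dMovingOutput)
  // is multiplied by the transform Jacobian, summed over the feature axis, and
  // scattered into the dense derivative at the non-zero Jacobian column indices,
  // i.e. the chain rule dLoss/dParameters = sum(modulator * dOutput/dParameters).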
  virtual torch::Tensor
  updateValueAndGetGradientModulator(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) = 0;
  void
  updateDerivativeInJacobianMode(torch::Tensor & jacobian, torch::Tensor & nonZeroJacobianIndices)
  {
    m_Derivative.index_add_(0, nonZeroJacobianIndices.flatten(), jacobian.flatten());
  }
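
  // Note: in Jacobian mode the gradient modulator is assumed to have been folded
  // into `jacobian` by the caller, so its entries are scattered into the
  // derivative vector without further weighting.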

  virtual double
  GetValue(double N) const
  {
    if (m_Normalization == 0)
    {
      m_Normalization = 1 / (m_Value / N);
    }
    return m_Normalization * m_Value / N;
  }
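
  // Note on normalization: for losses constructed with isLossNormalized == true,
  // m_Normalization starts at 0 and is fixed on the first GetValue() call so that
  // the initial value maps to 1; subsequent values are reported relative to the
  // loss at the starting position. Non-normalized losses use a constant factor 1.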

  virtual torch::Tensor
  GetDerivative(double N) const
  {
    return m_Normalization * m_Derivative.to(torch::kCPU) / N;
  }

  virtual ~Loss() = default;

  virtual Loss &
  operator+=(const Loss & other)
  {
    if (!m_Initialized && other.m_Initialized)
    {
      m_Value = other.m_Value;
      m_Derivative = other.m_Derivative;
      m_Initialized = true;
    }
    else if (other.m_Initialized)
    {
      m_Value += other.m_Value;
      m_Derivative += other.m_Derivative;
    }
    return *this;
  }
};

/** Singleton factory to register and create Loss instances by string name. */
class LossFactory
{
public:
  using CreatorFunc = std::function<std::unique_ptr<Loss>()>;

  static LossFactory &
  Instance()
  {
    static LossFactory instance;
    return instance;
  }

  void
  RegisterLoss(const std::string & name, CreatorFunc creator)
  {
    factoryMap[name] = creator;
  }

  std::unique_ptr<Loss>
  Create(const std::string & name)
  {
    auto it = factoryMap.find(name);
    if (it != factoryMap.end())
    {
      return it->second();
    }
    throw std::runtime_error("Error: Unknown loss function " + name);
  }

private:
  std::unordered_map<std::string, CreatorFunc> factoryMap;
};

template <typename T>
class RegisterLoss
{
public:
  RegisterLoss(const std::string & name)
  {
    LossFactory::Instance().RegisterLoss(name, []() { return std::make_unique<T>(); });
  }
};

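// Illustrative usage (caller-side sketch; the variable names below are
// hypothetical and not part of this header):
//
//   std::unique_ptr<ImpactLoss::Loss> loss =
//     ImpactLoss::LossFactory::Instance().Create("NCC");
//   loss->setNumberOfParameters(numberOfParameters);
//   loss->updateValue(fixedFeatures, movingFeatures); // repeat per batch
//   double value = loss->GetValue(numberOfSamples);
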
class L1 : public Loss
{
public:
  L1()
    : Loss(true)
  {}

  void
  updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    this->m_Value += (fixedOutput - movingOutput).abs().mean(1).sum().item<double>();
  }

  torch::Tensor
  updateValueAndGetGradientModulator(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    torch::Tensor diffOutput = fixedOutput - movingOutput;
    this->m_Value += diffOutput.abs().mean(1).sum().item<double>();
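    // Gradient w.r.t. the moving output: d/dm_j of mean_j |f_j - m_j| is
    // -sign(f_j - m_j) / C, with C = fixedOutput.size(1) the feature dimension.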
    return -torch::sign(diffOutput) / fixedOutput.size(1);
  }
};

inline RegisterLoss<L1> L1_reg("L1"); // Register the loss under its string name for factory-based creation

class L2 : public Loss
{
public:
  L2()
    : Loss(true)
  {}

  void
  updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    this->m_Value += (fixedOutput - movingOutput).pow(2).mean(1).sum().item<double>();
  }

  torch::Tensor
  updateValueAndGetGradientModulator(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    torch::Tensor diffOutput = fixedOutput - movingOutput;
    this->m_Value += diffOutput.pow(2).mean(1).sum().item<double>();
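    // Gradient w.r.t. the moving output: d/dm_j of mean_j (f_j - m_j)^2 is
    // -2 * (f_j - m_j) / C, with C = fixedOutput.size(1) the feature dimension.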
    return -2 * diffOutput / fixedOutput.size(1);
  }
};

inline RegisterLoss<L2> MSE_reg("L2"); // Register the loss under its string name for factory-based creation

class Dice : public Loss
{
public:
  Dice()
    : Loss(false)
  {}

  void
  updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    torch::Tensor intersectionSum = (fixedOutput * movingOutput).sum(1); // [N, ...]
    torch::Tensor unionSum = (fixedOutput + movingOutput).sum(1);        // [N, ...]

    // Detect degenerate case: no structure in either output (union == 0)
    torch::Tensor isEmpty = (unionSum == 0);

    // Make the denominator safe:
    //  - where unionSum != 0 -> keep unionSum
    //  - where unionSum == 0 -> replace by 1
    // Even though the degenerate positions are overwritten later
    // (dice.masked_fill_(isEmpty, 1.0)), we must avoid forming 0/0 here,
    // because the division is evaluated eagerly and would otherwise
    // create NaNs in intermediate tensors.
    torch::Tensor unionSumSafe = unionSum + isEmpty.to(unionSum.scalar_type());

    // Standard Dice formulation: 2 * intersection / union
    torch::Tensor dice = 2.0 * intersectionSum / unionSumSafe;

    // Convention for the degenerate case:
    // if both outputs are empty, force Dice = 1 (perfect similarity)
    dice.masked_fill_(isEmpty, 1.0);

    // Accumulate the loss value
    this->m_Value -= dice.sum().item<double>();
  }

  torch::Tensor
  updateValueAndGetGradientModulator(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);

    torch::Tensor intersectionSum = (fixedOutput * movingOutput).sum(1); // [N, ...]
    torch::Tensor unionSum = (fixedOutput + movingOutput).sum(1);        // [N, ...]

    // Detect degenerate case: no structure in either output (union == 0)
    torch::Tensor isEmpty = (unionSum == 0);

    // Make the denominator safe (see updateValue for rationale)
    torch::Tensor unionSumSafe = unionSum + isEmpty.to(unionSum.scalar_type());

    // Value: standard Dice formulation
    torch::Tensor dice = 2.0 * intersectionSum / unionSumSafe;

    // Convention for the degenerate case: empty/empty => Dice = 1
    dice.masked_fill_(isEmpty, 1.0);
    this->m_Value -= dice.sum().item<double>();

    // Gradient modulator:
    // standard: -2 * (fixedOutput * v - u) / v^2
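    // Quotient rule with u = intersectionSum and v = unionSumSafe: since
    // du/dm_j = f_j and dv/dm_j = 1, d(2u/v)/dm_j = 2 * (f_j * v - u) / v^2;
    // the leading minus matches the accumulated value -Dice.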
    torch::Tensor grad = -2.0 * (fixedOutput * unionSumSafe.unsqueeze(-1) - intersectionSum.unsqueeze(-1)) /
                         (unionSumSafe * unionSumSafe).unsqueeze(-1);

    // empty/empty => gradient = 0
    grad.masked_fill_(isEmpty.unsqueeze(-1), 0.0);

    return grad;
  }
};

inline RegisterLoss<Dice> Dice_reg("Dice"); // Register the loss under its string name for factory-based creation

class L1Cosine : public Loss
{
private:
  double m_Lambda;

public:
  L1Cosine()
    : Loss(false)
  {
    m_Lambda = 0.1;
  }

  void
  updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    torch::Tensor dotProduct = (fixedOutput * movingOutput).sum(1);
    torch::Tensor normFixed = torch::norm(fixedOutput, 2, 1);
    torch::Tensor normMoving = torch::norm(movingOutput, 2, 1);
    torch::Tensor cosine = dotProduct / (normFixed * normMoving);
    torch::Tensor expL1 = torch::exp(-m_Lambda * (fixedOutput - movingOutput).abs());
    this->m_Value -= (cosine.unsqueeze(-1) * expL1).mean(1).sum().item<double>();
  }

  torch::Tensor
  updateValueAndGetGradientModulator(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    torch::Tensor diffOutput = fixedOutput - movingOutput;
    torch::Tensor dotProduct = (fixedOutput * movingOutput).sum(1);
    torch::Tensor normFixed = torch::norm(fixedOutput, 2, 1);
    torch::Tensor normMoving = torch::norm(movingOutput, 2, 1);
    torch::Tensor v = (normFixed * normMoving);

    torch::Tensor cosine = dotProduct / (v);
    torch::Tensor expL1 = torch::exp(-m_Lambda * (fixedOutput - movingOutput).abs());

    torch::Tensor dCosine = -(fixedOutput / v.unsqueeze(-1) -
                              (fixedOutput * movingOutput * movingOutput) / (v * normMoving.pow(2)).unsqueeze(-1));
    torch::Tensor dexpL1 = -torch::sign(diffOutput) * expL1 / fixedOutput.size(1);
    this->m_Value -= (cosine.unsqueeze(-1) * expL1).mean(1).sum().item<double>();
    return dCosine * dexpL1 + cosine.unsqueeze(-1) * dexpL1;
  }
};

inline RegisterLoss<L1Cosine> L1CosineReg(
  "L1Cosine"); // Register the loss under its string name for factory-based creation

class Cosine : public Loss
{
public:
  Cosine()
    : Loss(false)
  {}

  void
  updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    torch::Tensor dotProduct = (fixedOutput * movingOutput).sum(1);
    torch::Tensor normFixed = torch::norm(fixedOutput, 2, 1);
    torch::Tensor normMoving = torch::norm(movingOutput, 2, 1);
    this->m_Value -= (dotProduct / (normFixed * normMoving)).sum().item<double>();
  }

  torch::Tensor
  updateValueAndGetGradientModulator(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    torch::Tensor dotProduct = (fixedOutput * movingOutput).sum(1);
    torch::Tensor normFixed = torch::norm(fixedOutput, 2, 1);
    torch::Tensor normMoving = torch::norm(movingOutput, 2, 1);
    torch::Tensor v = (normFixed * normMoving);
    this->m_Value -= (dotProduct / v).sum().item<double>();
    return -(fixedOutput / v.unsqueeze(-1) -
             (fixedOutput * movingOutput * movingOutput) / (v * normMoving.pow(2)).unsqueeze(-1));
  }
};

inline RegisterLoss<Cosine> CosineReg("Cosine"); // Register the loss under its string name for factory-based creation

class DotProduct : public Loss
{
public:
  DotProduct()
    : Loss(false)
  {}

  void
  updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    this->m_Value -= (fixedOutput * movingOutput).sum(1).sum().item<double>();
  }

  torch::Tensor
  updateValueAndGetGradientModulator(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    this->m_Value -= (fixedOutput * movingOutput).sum(1).sum().item<double>();
    return -fixedOutput;
  }
};

inline RegisterLoss<DotProduct> DotProductReg(
  "DotProduct"); // Register the loss under its string name for factory-based creation

/** Normalized Cross Correlation loss over feature vectors. */
class NCC : public Loss
{
private:
  torch::Tensor m_Sff, m_Smm, m_Sfm, m_Sf, m_Sm;
  torch::Tensor m_Sfdm, m_Smdm, m_Sdm;

public:
  NCC()
    : Loss(false)
  {}

  void
  initialize(torch::Tensor & output) override
  {
    if (!this->m_Initialized)
    {
      m_Sff = torch::zeros({ output.size(1) }, output.options());
      m_Smm = torch::zeros({ output.size(1) }, output.options());
      m_Sfm = torch::zeros({ output.size(1) }, output.options());
      m_Sf = torch::zeros({ output.size(1) }, output.options());
      m_Sm = torch::zeros({ output.size(1) }, output.options());
      m_Initialized = true;
    }
  }

  void
  updateValue(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    this->initialize(fixedOutput);
    m_Sff += (fixedOutput * fixedOutput).sum(0);
    m_Smm += (movingOutput * movingOutput).sum(0);
    m_Sfm += (fixedOutput * movingOutput).sum(0);
    m_Sf += fixedOutput.sum(0);
    m_Sm += movingOutput.sum(0);
  }
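
  // Note: unlike the pointwise losses above, NCC is not additive per sample;
  // only the sufficient statistics are accumulated here, and the value is
  // formed once from the pooled sums in GetValue().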

  void
  updateValueAndDerivativeInStaticMode(torch::Tensor & fixedOutput,
                                       torch::Tensor & movingOutput,
                                       torch::Tensor & jacobian,
                                       torch::Tensor & nonZeroJacobianIndices) override
  {
    // Accumulate first-order statistics and weighted Jacobians
    // sfdm: sum(fixed * dM), smdm: sum(moving * dM), sdm: sum(dM)
    if (!this->m_Initialized)
    {
      m_Sfdm = torch::zeros({ fixedOutput.size(1), m_NumberOfParameters }, fixedOutput.options());
      m_Smdm = torch::zeros({ fixedOutput.size(1), m_NumberOfParameters }, fixedOutput.options());
      m_Sdm = torch::zeros({ fixedOutput.size(1), m_NumberOfParameters }, fixedOutput.options());
    }
    this->updateValue(fixedOutput, movingOutput);
    m_Sfdm.index_add_(
      1, nonZeroJacobianIndices.flatten(), (fixedOutput.unsqueeze(-1) * jacobian).permute({ 1, 0, 2 }).flatten(1, 2));
    m_Smdm.index_add_(
      1, nonZeroJacobianIndices.flatten(), (movingOutput.unsqueeze(-1) * jacobian).permute({ 1, 0, 2 }).flatten(1, 2));
    m_Sdm.index_add_(1, nonZeroJacobianIndices.flatten(), (jacobian).permute({ 1, 0, 2 }).flatten(1, 2));
  }

  torch::Tensor
  updateValueAndGetGradientModulator(torch::Tensor & fixedOutput, torch::Tensor & movingOutput) override
  {
    if (!this->m_Initialized)
    {
      this->m_Derivative = torch::zeros({ this->m_NumberOfParameters }, fixedOutput.options());
    }
    this->initialize(fixedOutput);

    const double  N = fixedOutput.size(0);
    torch::Tensor sff = (fixedOutput * fixedOutput).sum(0);
    torch::Tensor smm = (movingOutput * movingOutput).sum(0);
    torch::Tensor sfm = (fixedOutput * movingOutput).sum(0);
    torch::Tensor sf = fixedOutput.sum(0);
    torch::Tensor sm = movingOutput.sum(0);

    m_Sff += sff;
    m_Smm += smm;
    m_Sfm += sfm;
    m_Sf += sf;
    m_Sm += sm;

    torch::Tensor u = sfm - (sf * sm / N);
    torch::Tensor v = torch::sqrt(sff - sf * sf / N) * torch::sqrt(smm - sm * sm / N); // v = a*b

    torch::Tensor u_p = fixedOutput - sf.unsqueeze(0) / N;
    return -((u_p - u.unsqueeze(0) * (movingOutput - sm.unsqueeze(0) / N) / (smm - sm * sm / N).unsqueeze(0)) /
             v.unsqueeze(0)) /
           fixedOutput.size(1);
  }

  double
  GetValue(double N) const override
  {
    // Compute NCC loss from accumulated statistics: mean( -NCC(channel) )
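    // Per channel: NCC = (Sfm - Sf*Sm/N) / sqrt((Sff - Sf^2/N) * (Smm - Sm^2/N)),
    // i.e. the Pearson correlation of the fixed and moving features pooled over
    // all samples seen so far.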
    if (N <= 0)
      return 0.0;
    torch::Tensor u = m_Sfm - (m_Sf * m_Sm / N);
    torch::Tensor v = torch::sqrt(m_Sff - m_Sf * m_Sf / N) * torch::sqrt(m_Smm - m_Sm * m_Sm / N);
    return -(u / v).mean().item<double>();
  }

  torch::Tensor
  GetDerivative(double N) const override
  {
    if (this->m_Derivative.defined())
    {
      return this->m_Derivative.to(torch::kCPU);
    }

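    // Quotient rule on NCC = u / v, with the derivatives of the accumulated sums
    // assembled from the weighted-Jacobian statistics m_Sfdm, m_Smdm and m_Sdm.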
    torch::Tensor u = m_Sfm - (m_Sf * m_Sm / N);
    torch::Tensor v = torch::sqrt(m_Sff - m_Sf * m_Sf / N) * torch::sqrt(m_Smm - m_Sm * m_Sm / N);
    torch::Tensor u_p = m_Sfdm - m_Sf.unsqueeze(-1) * m_Sdm / N;
    return -((u_p -
              u.unsqueeze(-1) * (m_Smdm - m_Sm.unsqueeze(-1) * m_Sdm / N) / (m_Smm - m_Sm * m_Sm / N).unsqueeze(-1)) /
             v.unsqueeze(-1))
      .mean(0)
      .to(torch::kCPU);
  }

  NCC &
  operator+=(const Loss & other) override
  {
    const auto * nccOther = dynamic_cast<const NCC *>(&other);
    if (nccOther)
    {
      m_Sff += nccOther->m_Sff;
      m_Smm += nccOther->m_Smm;
      m_Sfm += nccOther->m_Sfm;
      m_Sf += nccOther->m_Sf;
      m_Sm += nccOther->m_Sm;
      if (m_Sfdm.defined())
      {
        m_Sfdm += nccOther->m_Sfdm;
        m_Smdm += nccOther->m_Smdm;
        m_Sdm += nccOther->m_Sdm;
      }
      if (m_Derivative.defined())
      {
        m_Derivative += nccOther->m_Derivative;
      }
    }
    return *this;
  }
};

inline RegisterLoss<NCC> NCC_reg("NCC"); // Register the loss under its string name for factory-based creation

} // namespace ImpactLoss

#endif // _ImpactLoss_h