/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef otbTensorflowMultisourceModelFilter_txx
#define otbTensorflowMultisourceModelFilter_txx

#include "otbTensorflowMultisourceModelFilter.h"

namespace otb
{

template <class TInputImage, class TOutputImage>
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::TensorflowMultisourceModelFilter()
 {
  // Grid and mode defaults: the output grid is computed automatically
  // (in GenerateOutputInformation) unless the user forces it
  m_ForceOutputGridSize = false;
  m_FullyConvolutional = false;
  m_OutputGridSize.Fill(0);

  // Output geometry is derived from the inputs in GenerateOutputInformation
  m_OutputSize.Fill(0);
  m_OutputOrigin.Fill(0);
  m_OutputSpacing.Fill(0);

  // By default, output spacing equals the reference (first input) spacing
  m_OutputSpacingScale = 1.0f;

  // Input images may not be perfectly aligned: relax ITK's geometric
  // consistency checks so that slightly shifted sources are accepted
  Superclass::SetCoordinateTolerance(itk::NumericTraits<double>::max() );
  Superclass::SetDirectionTolerance(itk::NumericTraits<double>::max() );
 }

template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::SmartPad(RegionType& region, const SizeType &patchSize)
 {
  // Enlarge the region by one patch size in every dimension, putting
  // ceil(psz/2) on the lower side and floor(psz/2) on the upper side.
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    const SizeValueType psz = patchSize[dim];
    // Integer halving (previously written as 0.5 * psz, which silently
    // round-trips through double before truncating back to an integer)
    const SizeValueType rval = psz / 2;     // floor(psz / 2)
    const SizeValueType lval = psz - rval;  // ceil(psz / 2)
    region.GetModifiableIndex()[dim] -= lval;
    region.GetModifiableSize()[dim] += psz;
    }
 }

template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::SmartShrink(RegionType& region, const SizeType &patchSize)
 {
  // Shrink the region by one patch size in every dimension, removing
  // ceil(psz/2) from the lower side and floor(psz/2) from the upper side
  // (the exact inverse of SmartPad).
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    const SizeValueType psz = patchSize[dim];
    // Integer halving (previously 0.5 * psz via double arithmetic)
    const SizeValueType rval = psz / 2;     // floor(psz / 2)
    const SizeValueType lval = psz - rval;  // ceil(psz / 2)
    region.GetModifiableIndex()[dim] += lval;
    // Guard against unsigned wrap-around: if the region is smaller than
    // the patch, clamp the size to zero instead of underflowing
    if (region.GetSize(dim) > psz)
      {
      region.GetModifiableSize()[dim] -= psz;
      }
    else
      {
      region.GetModifiableSize()[dim] = 0;
      }
    }
 }

/**
  Compute the input image extent i.e. corners inf & sup
  Function taken from "Mosaic" and adapted
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::ImageToExtent(ImageType* image, PointType &extentInf, PointType &extentSup, SizeType &patchSize)
 {

  // Get largest possible region
  RegionType largestPossibleRegion = image->GetLargestPossibleRegion();

  // Shrink it a little with the FOV radius
  SmartShrink(largestPossibleRegion, patchSize);

  // Get index of first and last pixel
  IndexType imageFirstIndex = largestPossibleRegion.GetIndex();
  IndexType imageLastIndex = largestPossibleRegion.GetUpperIndex();

  // Compute extent
  PointType imageOrigin;
  PointType imageEnd;
  image->TransformIndexToPhysicalPoint(imageLastIndex, imageEnd);
  image->TransformIndexToPhysicalPoint(imageFirstIndex, imageOrigin);
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    extentInf[dim] = vnl_math_min(imageOrigin[dim], imageEnd[dim]);
    extentSup[dim] = vnl_math_max(imageOrigin[dim], imageEnd[dim]);
    }

 }

/**
  Compute the region of the input image which corresponds to the given output requested region.
  Return true if the region exists (i.e. overlaps the input's largest possible region), false if not.
  On success, inputRegion is cropped to the input's largest possible region.
  Function taken from "Mosaic"
 */
template <class TInputImage, class TOutputImage>
bool
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::OutputRegionToInputRegion(const RegionType &outputRegion, RegionType &inputRegion, ImageType* &inputImage)
 {

  // Mosaic Region Start & End (mosaic image index)
  const IndexType outIndexStart = outputRegion.GetIndex();
  const IndexType outIndexEnd = outputRegion.GetUpperIndex();

  // Mosaic Region Start & End (geo)
  PointType outPointStart, outPointEnd;
  this->GetOutput()->TransformIndexToPhysicalPoint(outIndexStart, outPointStart);
  this->GetOutput()->TransformIndexToPhysicalPoint(outIndexEnd  , outPointEnd  );

  // Add the half-width pixel size of the input image
  // and remove the half-width pixel size of the output image
  // (coordinates = pixel center): this makes the physical footprints of
  // input and output pixels match at the region borders, so the index
  // conversion below lands on the right input pixel.
  const SpacingType outputSpc = this->GetOutput()->GetSpacing();
  const SpacingType inputSpc = inputImage->GetSpacing();
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    const typename SpacingType::ValueType border =
        0.5 * (inputSpc[dim] - outputSpc[dim]);
    // The axis direction can be flipped (signed spacing): apply the border
    // toward the region interior in both cases
    if (outPointStart[dim] < outPointEnd[dim])
      {
      outPointStart[dim] += border;
      outPointEnd  [dim] -= border;
      }
    else
      {
      outPointStart[dim] -= border;
      outPointEnd  [dim] += border;
      }
    }

  // Mosaic Region Start & End (input image index)
  IndexType defIndexStart, defIndexEnd;
  inputImage->TransformPhysicalPointToIndex(outPointStart, defIndexStart);
  inputImage->TransformPhysicalPointToIndex(outPointEnd  , defIndexEnd);

  // Compute input image region: take min/max per dimension because the
  // start/end indices may be swapped by a flipped axis
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    inputRegion.SetIndex(dim, vnl_math_min(defIndexStart[dim], defIndexEnd[dim]));
    inputRegion.SetSize(dim, vnl_math_max(defIndexStart[dim], defIndexEnd[dim]) - inputRegion.GetIndex(dim) + 1);
    }

  // crop the input requested region at the input's largest possible region
  // (Crop returns false when the two regions do not overlap)
  return inputRegion.Crop( inputImage->GetLargestPossibleRegion() );

 }

/*
 * Enlarge the given region to the nearest aligned region.
 * Aligned region = Index and UpperIndex+1 are on the output grid
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::EnlargeToAlignedRegion(RegionType& region)
 {
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    const IndexValueType grid = m_OutputGridSize[dim];

    // Get corners
    IndexValueType lower = region.GetIndex(dim);
    IndexValueType upper = lower + static_cast<IndexValueType>(region.GetSize(dim));

    // Compute deltas between corners and the grid, using floored remainders
    // so that negative indices are also aligned downwards. Plain C++ '%'
    // truncates toward zero, which for a negative index would move the
    // lower corner *up* and silently shrink the region.
    const IndexValueType deltaLo = ((lower % grid) + grid) % grid;
    const IndexValueType deltaUp = ((upper % grid) + grid) % grid;

    // Move corners to aligned positions (lower goes down, upper goes up)
    lower -= deltaLo;
    if (deltaUp > 0)
      {
      upper += grid - deltaUp;
      }

    // Update region
    region.SetIndex(dim, lower);
    region.SetSize(dim, upper - lower);

    }
 }

template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::GenerateOutputInformation()
 {

  Superclass::GenerateOutputInformation();

  //////////////////////////////////////////////////////////////////////////////////////////
  //                            Compute the output image extent
  //////////////////////////////////////////////////////////////////////////////////////////

  // If the output spacing is not specified, we use the first input image as grid reference
  m_OutputSpacing = this->GetInput(0)->GetSignedSpacing();
  m_OutputSpacing[0] *= m_OutputSpacingScale;
  m_OutputSpacing[1] *= m_OutputSpacingScale;

  // Initialize the global extent so that the first update always wins:
  // sup starts at +max and only decreases, inf starts at -max and only increases
  PointType extentInf, extentSup;
  extentSup.Fill(itk::NumericTraits<double>::max());
  extentInf.Fill(itk::NumericTraits<double>::NonpositiveMin());

  // Compute the extent of each input image and update the global extent.
  // The global extent is the *intersection* of all input extents (max of
  // inferior corners, min of superior corners).
  for (unsigned int imageIndex = 0 ; imageIndex < this->GetNumberOfInputs() ; imageIndex++)
    {
    ImageType * currentImage = static_cast<ImageType *>(
        Superclass::ProcessObject::GetInput(imageIndex) );

    // Update output image extent
    PointType currentInputImageExtentInf, currentInputImageExtentSup;
    ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputReceptiveFields()[imageIndex]);
    for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim)
      {
      extentInf[dim] = vnl_math_max(currentInputImageExtentInf[dim], extentInf[dim]);
      extentSup[dim] = vnl_math_min(currentInputImageExtentSup[dim], extentSup[dim]);
      }
    }

  // Set final size
  m_OutputSize[0] = vcl_floor( (extentSup[0] - extentInf[0]) / vcl_abs(m_OutputSpacing[0]) ) + 1;
  m_OutputSize[1] = vcl_floor( (extentSup[1] - extentInf[1]) / vcl_abs(m_OutputSpacing[1]) ) + 1;

  // Set final origin: x at the inferior corner, y at the superior corner
  // (north-up convention with negative y signed spacing)
  m_OutputOrigin[0] =  extentInf[0];
  m_OutputOrigin[1] =  extentSup[1];

  // Set output grid size
  if (!m_ForceOutputGridSize)
    {
    // Default is the output field of expression
    m_OutputGridSize = this->GetOutputExpressionFields().at(0);
    }

  // Resize the largestPossibleRegion to be a multiple of the grid size
  for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim)
    {
    if (m_OutputGridSize[dim] > m_OutputSize[dim])
      itkGenericExceptionMacro("Output grid size is larger than output image size !");
    m_OutputSize[dim] -= m_OutputSize[dim] % m_OutputGridSize[dim];
    }

  // Set the largest possible region
  RegionType largestPossibleRegion;
  largestPossibleRegion.SetSize(m_OutputSize);

  //////////////////////////////////////////////////////////////////////////////////////////
  //                  Compute the output number of components per pixel
  //////////////////////////////////////////////////////////////////////////////////////////

  unsigned int outputPixelSize = 0;
  for (auto& protoShape: this->GetOutputTensorsShapes())
    {
    // The number of components per pixel is the last dimension of the tensor
    int dim_size = protoShape.dim_size();
    unsigned int nComponents = 1;
    if (1 < dim_size && dim_size <= 4)
      {
      nComponents = protoShape.dim(dim_size-1).size();
      }
    else if (dim_size > 4)
      {
      itkExceptionMacro("Dim_size=" << dim_size << " currently not supported.");
      }
    outputPixelSize += nComponents;
    }

  // Copy input image projection
  ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(0) );
  const std::string projectionRef = inputImage->GetProjectionRef();

  // Set output image origin/spacing/size/projection
  ImageType * outputPtr = this->GetOutput();
  outputPtr->SetNumberOfComponentsPerPixel(outputPixelSize);
  outputPtr->SetProjectionRef        ( projectionRef );
  outputPtr->SetOrigin               ( m_OutputOrigin );
  outputPtr->SetSignedSpacing        ( m_OutputSpacing );
  outputPtr->SetLargestPossibleRegion( largestPossibleRegion );

  // Set null pixel, used to initialize the output buffer before the model
  // outputs are copied in (GenerateData)
  m_NullPixel.SetSize(outputPtr->GetNumberOfComponentsPerPixel());
  m_NullPixel.Fill(0);

 }

template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::GenerateInputRequestedRegion()
 {
  Superclass::GenerateInputRequestedRegion();

  // Output requested region
  RegionType requestedRegion = this->GetOutput()->GetRequestedRegion();

  // First, align the output region on the output grid
  EnlargeToAlignedRegion(requestedRegion);

  // For each image, get the requested region
  for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i)
    {
    ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) );

    // Compute the requested region
    RegionType inRegion;
    if (!OutputRegionToInputRegion(requestedRegion, inRegion, inputImage) )
      {
      // Image does not overlap requested region: set requested region to null
      itkDebugMacro( <<  "Image #" << i << " :\n" << inRegion << " is outside the requested region");
      inRegion.GetModifiableIndex().Fill(0);
      inRegion.GetModifiableSize().Fill(0);
      }

    // Compute the FOV-scale*FOE radius to pad: the input receptive field
    // minus the part already covered by the (scaled) output expression field
    SizeType toPad(this->GetInputReceptiveFields().at(i));
    toPad[0] -= 1 + (this->GetOutputExpressionFields().at(0)[0] - 1) * m_OutputSpacingScale;
    toPad[1] -= 1 + (this->GetOutputExpressionFields().at(0)[1] - 1) * m_OutputSpacingScale;

    // Pad with radius
    SmartPad(inRegion, toPad);

    // We need to avoid some extrapolation when mode is patch-based.
    // The reason is that, when some input have a lower spacing than the
    // reference image, the requested region of this lower res input image
    // can be one pixel larger when the input image regions are not physically
    // aligned.
    if (!m_FullyConvolutional)
      {
      inRegion.PadByRadius(1);
      }

    // Clamp to what the input can actually provide
    inRegion.Crop(inputImage->GetLargestPossibleRegion());

    // Update the requested region
    inputImage->SetRequestedRegion(inRegion);

    } // next image

 }

/**
 * Compute the output image: feed the model with the input sources
 * (either as one full tile, or as a batch of patches) and copy the
 * resulting tensors into the output image buffer.
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::GenerateData()
 {
  // Output pointer and requested region
  typename TOutputImage::Pointer outputPtr = this->GetOutput();
  const RegionType outputReqRegion = outputPtr->GetRequestedRegion();

  // Get the aligned output requested region
  RegionType outputAlignedReqRegion(outputReqRegion);
  EnlargeToAlignedRegion(outputAlignedReqRegion);

  // Add a progress reporter
  itk::ProgressReporter progress(this, 0, outputReqRegion.GetNumberOfPixels());

  const unsigned int nInputs = this->GetNumberOfInputs();

  // Create input tensors list (placeholder name --> tensor)
  DictType inputs;

  // Populate input tensors
  for (unsigned int i = 0 ; i < nInputs ; i++)
    {
    // Input image pointer
    const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i));

    // Patch size of tensor #i
    const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i);

    // Input image requested region
    const RegionType reqRegion = inputPtr->GetRequestedRegion();

    if (m_FullyConvolutional)
      {
      // Shape of input tensor #i: a single (batch size 1) NHWC element
      // covering the whole requested region
      tensorflow::int64 sz_n = 1;
      tensorflow::int64 sz_y = reqRegion.GetSize(1);
      tensorflow::int64 sz_x = reqRegion.GetSize(0);
      tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel();
      tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c});

      // Create the input tensor
      tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape);

      // Recopy the whole input
      tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, reqRegion, inputTensor, 0);

      // Input is the tensor representing the subset of image
      DictElementType input = { this->GetInputPlaceholders()[i], inputTensor };
      inputs.push_back(input);
      }
    else
      {
      // Preparing patches
      // Shape of input tensor #i: one patch per output pixel (NHWC batch)
      tensorflow::int64 sz_n = outputReqRegion.GetNumberOfPixels();
      tensorflow::int64 sz_y = inputPatchSize[1];
      tensorflow::int64 sz_x = inputPatchSize[0];
      tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel();
      tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c});

      // Create the input tensor
      tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape);

      // Fill the input tensor.
      // We iterate over points which are located from the index iterator
      // moving through the output image requested region
      unsigned int elemIndex = 0;
      IndexIteratorType idxIt(outputPtr, outputReqRegion);
      for (idxIt.GoToBegin(); !idxIt.IsAtEnd(); ++idxIt)
        {
        // Get the coordinates of the current output pixel
        PointType point;
        outputPtr->TransformIndexToPhysicalPoint(idxIt.GetIndex(), point);

        // Sample the i-th input patch centered on the point
        tf::SampleCenteredPatch<TInputImage>(inputPtr, point, inputPatchSize, inputTensor, elemIndex);
        elemIndex++;
        }

      // Input is the tensor of patches (aka the batch)
      DictElementType input = { this->GetInputPlaceholders()[i], inputTensor };
      inputs.push_back(input);

      } // mode is not full convolutional

    } // next input tensor

  // Run session
  TensorListType outputs;
  this->RunSession(inputs, outputs);

  // Fill the output buffer with zero value
  outputPtr->SetBufferedRegion(outputReqRegion);
  outputPtr->Allocate();
  outputPtr->FillBuffer(m_NullPixel);

  // Get output tensors
  int bandOffset = 0;
  for (unsigned int i = 0 ; i < outputs.size() ; i++)
    {
    // The offset (i.e. the starting index of the channel for the output tensor) is updated
    // during this call
    // TODO: implement a generic strategy enabling expression field copy in patch-based mode (see tf::CopyTensorToImageRegion)
    try
      {
      tf::CopyTensorToImageRegion<TOutputImage> (outputs[i],
          outputAlignedReqRegion, outputPtr, outputReqRegion, bandOffset);
      }
    catch( itk::ExceptionObject & err )
      {
      // Attach a report of the fed inputs to help diagnose shape mismatches
      std::stringstream debugMsg = this->GenerateDebugReport(inputs);
      itkExceptionMacro("Error occured during tensor to image conversion.\n"
          << "Context: " << debugMsg.str()
          << "Error:" << err);
      }
    }

 }


} // end namespace otb


#endif