/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef otbTensorflowMultisourceModelFilter_txx
#define otbTensorflowMultisourceModelFilter_txx

#include "otbTensorflowMultisourceModelFilter.h"

namespace otb
{

template <class TInputImage, class TOutputImage>
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::TensorflowMultisourceModelFilter()
 {
  m_OutputGridSize.Fill(0);
  m_ForceOutputGridSize = false;
  m_FullyConvolutional = false;

  m_OutputSpacing.Fill(0);
  m_OutputOrigin.Fill(0);
  m_OutputSize.Fill(0);

  m_OutputSpacingScale = 1.0f;

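  // Relax ITK's geometric consistency checks between inputs and output, since the
  // sources may have different spacings and origins (assumption: this is why the
  // tolerances are set to the maximum representable value).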
  Superclass::SetCoordinateTolerance(itk::NumericTraits<double>::max() );
  Superclass::SetDirectionTolerance(itk::NumericTraits<double>::max() );
 }

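/**
  Pad the given region with the patch size: in each dimension, the lower index
  moves back by ceil(patchSize/2) and the size grows by patchSize.
 */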
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::SmartPad(RegionType& region, const SizeType &patchSize)
 {
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    const SizeValueType psz = patchSize[dim];
    const SizeValueType rval = 0.5 * psz;
    const SizeValueType lval = psz - rval;
    region.GetModifiableIndex()[dim] -= lval;
    region.GetModifiableSize()[dim] += psz;
    }
 }

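/**
  Shrink the given region with the patch size: in each dimension, the lower index
  moves forward by floor(patchSize/2) and the size is reduced by patchSize-1.
 */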
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::SmartShrink(RegionType& region, const SizeType &patchSize)
 {
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    const SizeValueType psz = patchSize[dim];
    const SizeValueType lval = 0.5 * psz;
    region.GetModifiableIndex()[dim] += lval;
    region.GetModifiableSize()[dim] -= psz - 1;
    }
 }

/**
  Compute the input image extent, i.e. its inferior and superior corners in physical coordinates.
  Function taken from "Mosaic" and adapted.
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::ImageToExtent(ImageType* image, PointType &extentInf, PointType &extentSup, SizeType &patchSize)
 {

  // Get largest possible region
  RegionType largestPossibleRegion = image->GetLargestPossibleRegion();

  // Shrink it a little with the FOV radius
  SmartShrink(largestPossibleRegion, patchSize);

  // Get index of first and last pixel
  IndexType imageFirstIndex = largestPossibleRegion.GetIndex();
  IndexType imageLastIndex = largestPossibleRegion.GetUpperIndex();

  // Compute extent
  PointType imageOrigin;
  PointType imageEnd;
  image->TransformIndexToPhysicalPoint(imageLastIndex, imageEnd);
  image->TransformIndexToPhysicalPoint(imageFirstIndex, imageOrigin);
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    extentInf[dim] = vnl_math_min(imageOrigin[dim], imageEnd[dim]);
    extentSup[dim] = vnl_math_max(imageOrigin[dim], imageEnd[dim]);
    }

 }

/**
  Compute the region of the input image which corresponds to the given output requested region.
  Return true if the region exists, false if not.
  Function taken from "Mosaic".
 */
template <class TInputImage, class TOutputImage>
bool
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::OutputRegionToInputRegion(const RegionType &outputRegion, RegionType &inputRegion, ImageType* &inputImage)
 {

  // Mosaic Region Start & End (mosaic image index)
  const IndexType outIndexStart = outputRegion.GetIndex();
  const IndexType outIndexEnd = outputRegion.GetUpperIndex();

  // Mosaic Region Start & End (geo)
  PointType outPointStart, outPointEnd;
  this->GetOutput()->TransformIndexToPhysicalPoint(outIndexStart, outPointStart);
  this->GetOutput()->TransformIndexToPhysicalPoint(outIndexEnd  , outPointEnd  );

  // Add the half-width pixel size of the input image
  // and remove the half-width pixel size of the output image
  // (coordinates = pixel center)
  const SpacingType outputSpc = this->GetOutput()->GetSpacing();
  const SpacingType inputSpc = inputImage->GetSpacing();
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    const typename SpacingType::ValueType border =
        0.5 * (inputSpc[dim] - outputSpc[dim]);
    if (outPointStart[dim] < outPointEnd[dim])
      {
      outPointStart[dim] += border;
      outPointEnd  [dim] -= border;
      }
    else
      {
      outPointStart[dim] -= border;
      outPointEnd  [dim] += border;
      }
    }

  // Mosaic Region Start & End (input image index)
  IndexType defIndexStart, defIndexEnd;
  inputImage->TransformPhysicalPointToIndex(outPointStart, defIndexStart);
  inputImage->TransformPhysicalPointToIndex(outPointEnd  , defIndexEnd);

  // Compute input image region
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    inputRegion.SetIndex(dim, vnl_math_min(defIndexStart[dim], defIndexEnd[dim]));
    inputRegion.SetSize(dim, vnl_math_max(defIndexStart[dim], defIndexEnd[dim]) - inputRegion.GetIndex(dim) + 1);
    }

  // crop the input requested region at the input's largest possible region
  return inputRegion.Crop( inputImage->GetLargestPossibleRegion() );

 }

/*
 * Enlarge the given region to the nearest aligned region.
 * Aligned region = Index and UpperIndex+1 are on the output grid
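 * Example: with an output grid size of 16, a region with index 5 and size 20 is enlarged to index 0 and size 32.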
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::EnlargeToAlignedRegion(RegionType& region)
 {
  for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
    {
    // Get corners
    IndexValueType lower = region.GetIndex(dim);
    IndexValueType upper = lower + region.GetSize(dim);

    // Compute deltas between corners and the grid
    const IndexValueType deltaLo = lower % m_OutputGridSize[dim];
    const IndexValueType deltaUp = upper % m_OutputGridSize[dim];

    // Move corners to aligned positions
    lower -= deltaLo;
    if (deltaUp > 0)
      {
      upper += m_OutputGridSize[dim] - deltaUp;
      }

    // Update region
    region.SetIndex(dim, lower);
    region.SetSize(dim, upper - lower);

    }
 }

template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::GenerateOutputInformation()
 {

  Superclass::GenerateOutputInformation();

  //////////////////////////////////////////////////////////////////////////////////////////
  //                            Compute the output image extent
  //////////////////////////////////////////////////////////////////////////////////////////

  // The output spacing is taken from the first input image (grid reference), scaled by the spacing scale
  m_OutputSpacing = this->GetInput(0)->GetSignedSpacing();
  m_OutputSpacing[0] *= m_OutputSpacingScale;
  m_OutputSpacing[1] *= m_OutputSpacingScale;
  PointType extentInf, extentSup;
  extentSup.Fill(itk::NumericTraits<double>::max());
  extentInf.Fill(itk::NumericTraits<double>::NonpositiveMin());
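  // The output extent is the intersection of the input extents: extentInf only
  // grows (max) and extentSup only shrinks (min) in the loop below.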

  // Compute the extent of each input image and update the global extent
  for (unsigned int imageIndex = 0 ; imageIndex < this->GetNumberOfInputs() ; imageIndex++)
    {
    ImageType * currentImage = static_cast<ImageType *>(
        Superclass::ProcessObject::GetInput(imageIndex) );

    // Update output image extent
    PointType currentInputImageExtentInf, currentInputImageExtentSup;
    ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputReceptiveFields()[imageIndex]);
    for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim)
      {
      extentInf[dim] = vnl_math_max(currentInputImageExtentInf[dim], extentInf[dim]);
      extentSup[dim] = vnl_math_min(currentInputImageExtentSup[dim], extentSup[dim]);
      }
    }

  // Set final size
  m_OutputSize[0] = vcl_floor( (extentSup[0] - extentInf[0]) / vcl_abs(m_OutputSpacing[0]) ) + 1;
  m_OutputSize[1] = vcl_floor( (extentSup[1] - extentInf[1]) / vcl_abs(m_OutputSpacing[1]) ) + 1;

  // Set final origin
  m_OutputOrigin[0] =  extentInf[0];
  m_OutputOrigin[1] =  extentSup[1];
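  // Note: the origin is the north-west corner (min X, max Y), which assumes the
  // usual geospatial convention of a negative Y signed spacing.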

  // Set output grid size
  if (!m_ForceOutputGridSize)
    {
    // Default: the expression field of the first output tensor
    m_OutputGridSize = this->GetOutputExpressionFields().at(0);
    }

  // Resize the largestPossibleRegion to be a multiple of the grid size
  for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim)
    {
    if (m_OutputGridSize[dim] > m_OutputSize[dim])
      itkGenericExceptionMacro("Output grid size is larger than output image size!");
    m_OutputSize[dim] -= m_OutputSize[dim] % m_OutputGridSize[dim];
    }

  // Set the largest possible region
  RegionType largestPossibleRegion;
  largestPossibleRegion.SetSize(m_OutputSize);

  //////////////////////////////////////////////////////////////////////////////////////////
  //                  Compute the output number of components per pixel
  //////////////////////////////////////////////////////////////////////////////////////////

  unsigned int outputPixelSize = 0;
  for (auto& protoShape: this->GetOutputTensorsShapes())
    {
    // The number of components per pixel is the last dimension of the tensor
    int dim_size = protoShape.dim_size();
    unsigned int nComponents = 1;
    if (1 < dim_size && dim_size <= 4)
      {
      nComponents = protoShape.dim(dim_size-1).size();
      }
    else if (dim_size > 4)
      {
      itkExceptionMacro("Dim_size=" << dim_size << " currently not supported.");
      }
    outputPixelSize += nComponents;
    }

  // Copy input image projection
  ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(0) );
  const std::string projectionRef = inputImage->GetProjectionRef();

  // Set output image origin/spacing/size/projection
  ImageType * outputPtr = this->GetOutput();
  outputPtr->SetNumberOfComponentsPerPixel(outputPixelSize);
  outputPtr->SetProjectionRef        ( projectionRef );
  outputPtr->SetOrigin               ( m_OutputOrigin );
  outputPtr->SetSignedSpacing        ( m_OutputSpacing );
  outputPtr->SetLargestPossibleRegion( largestPossibleRegion );

  // Set null pixel
  m_NullPixel.SetSize(outputPtr->GetNumberOfComponentsPerPixel());
  m_NullPixel.Fill(0);

  //////////////////////////////////////////////////////////////////////////////////////////
  //                        Set the tiling layout hint in metadata
  //////////////////////////////////////////////////////////////////////////////////////////

  itk::EncapsulateMetaData(outputPtr->GetMetaDataDictionary(), MetaDataKey::TileHintX, m_OutputGridSize[0]);
  itk::EncapsulateMetaData(outputPtr->GetMetaDataDictionary(), MetaDataKey::TileHintY, m_OutputGridSize[1]);

 }

template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::GenerateInputRequestedRegion()
 {
  Superclass::GenerateInputRequestedRegion();

  // Output requested region
  RegionType requestedRegion = this->GetOutput()->GetRequestedRegion();

  // First, align the output region
  EnlargeToAlignedRegion(requestedRegion);

  // For each image, get the requested region
  for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i)
    {
    ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) );

    // Compute the requested region
    RegionType inRegion;
    if (!OutputRegionToInputRegion(requestedRegion, inRegion, inputImage) )
      {
      // Image does not overlap requested region: set requested region to null
      itkDebugMacro( <<  "Image #" << i << " :\n" << inRegion << " is outside the requested region");
      inRegion.GetModifiableIndex().Fill(0);
      inRegion.GetModifiableSize().Fill(0);
      }

    // Compute the FOV-scale*FOE radius to pad
    SizeType toPad(this->GetInputReceptiveFields().at(i));
    toPad[0] -= 1 + (this->GetOutputExpressionFields().at(0)[0] - 1) * m_OutputSpacingScale;
    toPad[1] -= 1 + (this->GetOutputExpressionFields().at(0)[1] - 1) * m_OutputSpacingScale;
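    // toPad is thus the receptive field minus the input-pixel footprint of the
    // output expression field at the current spacing scale; SmartPad splits it
    // roughly in half on each side of the region.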

    // Pad with radius
    SmartPad(inRegion, toPad);

    // We need to avoid some extrapolation when the mode is patch-based.
    // The reason is that, when some inputs have a lower spacing than the
    // reference image, the requested region of such a lower-resolution input
    // image can be one pixel larger when the input image regions are not
    // physically aligned.
    if (!m_FullyConvolutional)
      {
      inRegion.PadByRadius(1);
      }

    inRegion.Crop(inputImage->GetLargestPossibleRegion());

    // Update the requested region
    inputImage->SetRequestedRegion(inRegion);

    } // next image

 }

/**
 * Compute the output image
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
::GenerateData()
 {
  // Output pointer and requested region
  typename TOutputImage::Pointer outputPtr = this->GetOutput();
  const RegionType outputReqRegion = outputPtr->GetRequestedRegion();

  // Get the aligned output requested region
  RegionType outputAlignedReqRegion(outputReqRegion);
  EnlargeToAlignedRegion(outputAlignedReqRegion);

  const unsigned int nInputs = this->GetNumberOfInputs();

  // Create input tensors list
  DictType inputs;

  // Populate input tensors
  for (unsigned int i = 0 ; i < nInputs ; i++)
    {
    // Input image pointer
    const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i));

    // Patch size of tensor #i
    const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i);

    // Input image requested region
    const RegionType reqRegion = inputPtr->GetRequestedRegion();

    if (m_FullyConvolutional)
      {
      // Shape of input tensor #i
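      // (a single batch element covering the whole requested region, in NHWC order)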
      tensorflow::int64 sz_n = 1;
      tensorflow::int64 sz_y = reqRegion.GetSize(1);
      tensorflow::int64 sz_x = reqRegion.GetSize(0);
      tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel();
      tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c});

      // Create the input tensor
      tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape);

      // Recopy the whole input
      tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, reqRegion, inputTensor, 0);

      // Input is the tensor representing the subset of image
      DictElementType input = { this->GetInputPlaceholders()[i], inputTensor };
      inputs.push_back(input);
      }
    else
      {
      // Preparing patches
      // Shape of input tensor #i
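      // (one patch per pixel of the output requested region, stacked along the batch dimension, in NHWC order)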
      tensorflow::int64 sz_n = outputReqRegion.GetNumberOfPixels();
      tensorflow::int64 sz_y = inputPatchSize[1];
      tensorflow::int64 sz_x = inputPatchSize[0];
      tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel();
      tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c});

      // Create the input tensor
      tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape);

      // Fill the input tensor.
      // We iterate over points which are located from the index iterator
      // moving through the output image requested region
      unsigned int elemIndex = 0;
      IndexIteratorType idxIt(outputPtr, outputReqRegion);
      for (idxIt.GoToBegin(); !idxIt.IsAtEnd(); ++idxIt)
        {
        // Get the coordinates of the current output pixel
        PointType point;
        outputPtr->TransformIndexToPhysicalPoint(idxIt.GetIndex(), point);

        // Sample the i-th input patch centered on the point
        tf::SampleCenteredPatch<TInputImage>(inputPtr, point, inputPatchSize, inputTensor, elemIndex);
        elemIndex++;
        }

      // Input is the tensor of patches (aka the batch)
      DictElementType input = { this->GetInputPlaceholders()[i], inputTensor };
      inputs.push_back(input);

      } // mode is not full convolutional

    } // next input tensor

  // Run session
  TensorListType outputs;
  this->RunSession(inputs, outputs);

  // Fill the output buffer with zero value
  outputPtr->SetBufferedRegion(outputReqRegion);
  outputPtr->Allocate();
  outputPtr->FillBuffer(m_NullPixel);

  // Get output tensors
  int bandOffset = 0;
  for (unsigned int i = 0 ; i < outputs.size() ; i++)
    {
    // The offset (i.e. the starting index of the channel for the output tensor) is updated
    // during this call
    // TODO: implement a generic strategy enabling expression field copy in patch-based mode (see tf::CopyTensorToImageRegion)
    try
      {
      tf::CopyTensorToImageRegion<TOutputImage> (outputs[i],
          outputAlignedReqRegion, outputPtr, outputReqRegion, bandOffset);
      }
    catch( itk::ExceptionObject & err )
      {
      std::stringstream debugMsg = this->GenerateDebugReport(inputs);
      itkExceptionMacro("Error occurred during tensor to image conversion.\n"
          << "Context: " << debugMsg.str()
          << "Error: " << err);
      }
    }

 }


} // end namespace otb


#endif