// File: otbTrainClassifierFromDeepFeatures.cxx
/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#include "itkFixedArray.h"
#include "itkObjectFactory.h"

// Elevation handler
#include "otbWrapperElevationParametersHandler.h"
#include "otbWrapperApplicationFactory.h"
#include "otbWrapperCompositeApplication.h"

// Application engine
#include "otbStandardFilterWatcher.h"
#include "itkFixedArray.h"

// TF (used to get the environment variable for the number of inputs)
#include "otbTensorflowCommon.h"

// Standard library
#include <string>

namespace otb
{

namespace Wrapper
{

class TrainClassifierFromDeepFeatures : public CompositeApplication
{
public:
  /** Standard class typedefs. */
  typedef TrainClassifierFromDeepFeatures              Self;
  typedef Application                         Superclass;
  typedef itk::SmartPointer<Self>             Pointer;
  typedef itk::SmartPointer<const Self>       ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);
  itkTypeMacro(TrainClassifierFromDeepFeatures, otb::Wrapper::CompositeApplication);

private:

  //
  // Add an input source, which includes:
  // -an input image list
  // -an input patchsize (dimensions of samples)
  //
  void AddAnInputImage(int inputNumber = 0)
  {
    inputNumber++;

    // Create keys and descriptions
    std::stringstream ss_key_group, ss_desc_group;
    ss_key_group << "source" << inputNumber;
    ss_desc_group << "Parameters for source " << inputNumber;

    // Populate group
    ShareParameter(ss_key_group.str(), "tfmodel." + ss_key_group.str(), ss_desc_group.str());

  }

  void DoInit()
  {

  SetName("TrainClassifierFromDeepFeatures");
  SetDescription("Train a classifier from deep net based features of an image and training vector data.");

  // Documentation
  SetDocLongDescription("See TrainImagesClassifier application");
  SetDocLimitations("None");
  SetDocAuthors("Remi Cresson");
  SetDocSeeAlso(" ");

  ClearApplications();

  // Add applications
Cresson Remi's avatar
Cresson Remi committed
81
82
  AddApplication("TrainImagesClassifier",  "train",   "Train images classifier");
  AddApplication("TensorflowModelServe",   "tfmodel", "Serve the TF model");
remi cresson's avatar
remi cresson committed
83
84
85
86
87
88
89

  // Model shared parameters
  AddAnInputImage();
  for (int i = 1; i < tf::GetNumberOfSources() ; i++)
  {
    AddAnInputImage(i);
  }
Cresson Remi's avatar
Cresson Remi committed
90
91
92
  ShareParameter("model",      "tfmodel.model",       "Deep net inputs parameters",   "Parameters of the deep net inputs: placeholder names, receptive fields, etc.");
  ShareParameter("output",     "tfmodel.output",      "Deep net outputs parameters",  "Parameters of the deep net outputs: tensors names, expression fields, etc.");
  ShareParameter("optim",      "tfmodel.optim",       "Processing time optimization", "This group of parameters allows optimization of processing time");
remi cresson's avatar
remi cresson committed
93
94

  // Train shared parameters
95
  ShareParameter("ram",        "train.ram",           "Available RAM (Mb)",           "Available RAM (Mb)");
Cresson Remi's avatar
Cresson Remi committed
96
97
98
99
  ShareParameter("vd",         "train.io.vd",         "Vector data for training",     "Input vector data for training");
  ShareParameter("valid",      "train.io.valid",      "Vector data for validation",   "Input vector data for validation");
  ShareParameter("out",        "train.io.out",        "Output classification model",  "Output classification model");
  ShareParameter("confmatout", "train.io.confmatout", "Output confusion matrix",      "Output confusion matrix of the classification model");
remi cresson's avatar
remi cresson committed
100
101

  // Shared parameter groups
Cresson Remi's avatar
Cresson Remi committed
102
103
104
105
  ShareParameter("sample",     "train.sample",        "Sampling parameters" ,         "Training and validation samples parameters" );
  ShareParameter("elev",       "train.elev",          "Elevation parameters",         "Elevation parameters" );
  ShareParameter("classifier", "train.classifier",    "Classifier parameters",        "Classifier parameters" );
  ShareParameter("rand",       "train.rand",          "User defined random seed",     "User defined random seed" );
remi cresson's avatar
remi cresson committed
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128

  }


  void DoUpdateParameters()
  {
    UpdateInternalParameters("train");
  }

  void DoExecute()
  {
    ExecuteInternal("tfmodel");
    GetInternalApplication("train")->AddImageToParameterInputImageList("io.il", GetInternalApplication("tfmodel")->GetParameterOutputImage("out"));
    UpdateInternalParameters("train");
    ExecuteInternal("train");
  }   // DOExecute()

  void AfterExecuteAndWriteOutputs()
  {
    // Nothing to do
  }

};
Cresson Remi's avatar
Cresson Remi committed
129
130
} // namespace Wrapper
} // namespace otb
remi cresson's avatar
remi cresson committed
131
132

OTB_APPLICATION_EXPORT( otb::Wrapper::TrainClassifierFromDeepFeatures )