diff --git a/include/otbTensorflowMultisourceModelBase.h b/include/otbTensorflowMultisourceModelBase.h index 9c430e68a719c91d56774047db95d689a2b98536..dc025bcb66c48ddafa9fed8ed486c53762ddebe7 100644 --- a/include/otbTensorflowMultisourceModelBase.h +++ b/include/otbTensorflowMultisourceModelBase.h @@ -50,14 +50,14 @@ namespace otb * Target nodes names of the TensorFlow graph that must be triggered can be set * with the SetTargetNodesNames. * - * The OutputTensorNames consists in a strd::vector of std::string, and + * The OutputTensorNames consists in a std::vector of std::string, and * corresponds to the names of tensors that will be computed during the session. * As for input placeholders, output tensors field of expression * (OutputExpressionFields, a std::vector of SizeType), i.e. the output * space that the TensorFlow model will "generate", must be provided. * * Finally, a list of scalar placeholders can be fed in the form of std::vector - * of std::string, each one expressing the assigment of a signle valued + * of std::string, each one expressing the assignment of a single valued * placeholder, e.g. "drop_rate=0.5 learning_rate=0.002 toto=true". * See otb::tf::ExpressionToTensor() to know more about syntax. * diff --git a/include/otbTensorflowMultisourceModelFilter.hxx b/include/otbTensorflowMultisourceModelFilter.hxx index 77902bbdf6c045fe8fa2a9f2136919dd781ff54c..dbed34f753eca958289e41313c9cf34681773981 100644 --- a/include/otbTensorflowMultisourceModelFilter.hxx +++ b/include/otbTensorflowMultisourceModelFilter.hxx @@ -57,10 +57,9 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim) { const SizeValueType psz = patchSize[dim]; - const SizeValueType rval = 0.5 * psz; - const SizeValueType lval = psz - rval; + const SizeValueType lval = 0.5 * psz; region.GetModifiableIndex()[dim] += lval; - region.GetModifiableSize()[dim] -= psz; + region.GetModifiableSize()[dim] -= psz - 1; } } @@ -327,8 +326,19 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Compute the FOV-scale*FOE radius to pad SizeType toPad(this->GetInputReceptiveFields().at(i)); - toPad[0] -= 1 + (this->GetOutputExpressionFields().at(0)[0] - 1) * m_OutputSpacingScale; - toPad[1] -= 1 + (this->GetOutputExpressionFields().at(0)[1] - 1) * m_OutputSpacingScale; + for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim) + { + int valToPad = 1 + (this->GetOutputExpressionFields().at(0)[dim] - 1) * m_OutputSpacingScale * this->GetInput(0)->GetSpacing()[dim] / this->GetInput(i)->GetSpacing()[dim] ; + if (valToPad > toPad[dim]) + itkExceptionMacro("The input requested region of source #" << i << " is not consistent (dim "<< dim<< ")." << + "Please check RF, EF, SF vs physical spacing of your image!" << + "\nReceptive field: " << this->GetInputReceptiveFields().at(i)[dim] << + "\nExpression field: " << this->GetOutputExpressionFields().at(0)[dim] << + "\nScale factor: " << m_OutputSpacingScale << + "\nReference image spacing: " << this->GetInput(0)->GetSpacing()[dim] << + "\nImage " << i << " spacing: " << this->GetInput(i)->GetSpacing()[dim]); + toPad[dim] -= valToPad; + } // Pad with radius SmartPad(inRegion, toPad); diff --git a/include/otbTensorflowMultisourceModelLearningBase.h b/include/otbTensorflowMultisourceModelLearningBase.h index 930b8366083355a715cc9910b3068384867a2f88..f5ada7f25c2a8a1859818887e8cc0f7463924492 100644 --- a/include/otbTensorflowMultisourceModelLearningBase.h +++ b/include/otbTensorflowMultisourceModelLearningBase.h @@ -23,7 +23,15 @@ namespace otb /** * \class TensorflowMultisourceModelLearningBase - * \brief This filter is the base class for all learning filters. + * \brief This filter is the base class for all filters that input patches images. + * + * One input patches image consist in an image of size (pszx, pszy*n, nbands) where: + * -pszx : is the width of one patch + * -pszy : is the height of one patch + * -n : is the number of patches in the patches image + * -nbands : is the number of channels in the patches image + * + * This filter verify that every patches images are consistent. * * The batch size can be set using the SetBatchSize() method. * The streaming can be activated to allow the processing of huge datasets. diff --git a/python/create_savedmodel_pxs_fcn.py b/python/create_savedmodel_pxs_fcn.py new file mode 100755 index 0000000000000000000000000000000000000000..bb57c18eaface0435bd56237bd4af72246d1df0d --- /dev/null +++ b/python/create_savedmodel_pxs_fcn.py @@ -0,0 +1,74 @@ +from tricks import * +import sys +import os + +nclasses=8 + +def myModel(x1,x2): + + # The XS branch (input patches: 8x8x4) + conv1_x1 = tf.layers.conv2d(inputs=x1, filters=16, kernel_size=[5,5], padding="valid", + activation=tf.nn.relu) # out size: 4x4x16 + conv2_x1 = tf.layers.conv2d(inputs=conv1_x1, filters=32, kernel_size=[3,3], padding="valid", + activation=tf.nn.relu) # out size: 2x2x32 + conv3_x1 = tf.layers.conv2d(inputs=conv2_x1, filters=64, kernel_size=[2,2], padding="valid", + activation=tf.nn.relu) # out size: 1x1x64 + + # The PAN branch (input patches: 32x32x1) + conv1_x2 = tf.layers.conv2d(inputs=x2, filters=16, kernel_size=[5,5], padding="valid", + activation=tf.nn.relu) # out size: 28x28x16 + pool1_x2 = tf.layers.max_pooling2d(inputs=conv1_x2, pool_size=[2, 2], + strides=2) # out size: 14x14x16 + conv2_x2 = tf.layers.conv2d(inputs=pool1_x2, filters=32, kernel_size=[5,5], padding="valid", + activation=tf.nn.relu) # out size: 10x10x32 + pool2_x2 = tf.layers.max_pooling2d(inputs=conv2_x2, pool_size=[2, 2], + strides=2) # out size: 5x5x32 + conv3_x2 = tf.layers.conv2d(inputs=pool2_x2, filters=64, kernel_size=[3,3], padding="valid", + activation=tf.nn.relu) # out size: 3x3x64 + conv4_x2 = tf.layers.conv2d(inputs=conv3_x2, filters=64, kernel_size=[3,3], padding="valid", + activation=tf.nn.relu) # out size: 1x1x64 + + # Stack features + features = tf.reshape(tf.stack([conv3_x1, conv4_x2], axis=3), + shape=[-1, 128], name="features") + + # 8 neurons for 8 classes + estimated = tf.layers.dense(inputs=features, units=nclasses, activation=None) + estimated_label = tf.argmax(estimated, 1, name="prediction") + + return estimated, estimated_label + +""" Main """ +# check number of arguments +if len(sys.argv) != 2: + print("Usage : <output directory for SavedModel>") + sys.exit(1) + +# Create the graph +with tf.Graph().as_default(): + + # Placeholders + x1 = tf.placeholder(tf.float32, [None, None, None, 4], name="x1") + x2 = tf.placeholder(tf.float32, [None, None, None, 1], name="x2") + y = tf.placeholder(tf.int32 , [None, None, None, 1], name="y") + lr = tf.placeholder_with_default(tf.constant(0.0002, dtype=tf.float32, shape=[]), + shape=[], name="lr") + + # Output + y_estimated, y_label = myModel(x1,x2) + + # Loss function + cost = tf.losses.sparse_softmax_cross_entropy(labels=tf.reshape(y, [-1, 1]), + logits=tf.reshape(y_estimated, [-1, nclasses])) + + # Optimizer + optimizer = tf.train.AdamOptimizer(learning_rate=lr, name="optimizer").minimize(cost) + + # Initializer, saver, session + init = tf.global_variables_initializer() + saver = tf.train.Saver( max_to_keep=20 ) + sess = tf.Session() + sess.run(init) + + # Create a SavedModel + CreateSavedModel(sess, ["x1:0", "x2:0", "y:0"], ["features:0", "prediction:0"], sys.argv[1]) diff --git a/python/create_savedmodel_simple_cnn.py b/python/create_savedmodel_simple_cnn.py new file mode 100755 index 0000000000000000000000000000000000000000..2cd79f1884de4028f050c6e0ac2a657446f0ed38 --- /dev/null +++ b/python/create_savedmodel_simple_cnn.py @@ -0,0 +1,59 @@ +from tricks import * +import sys +import os + +nclasses=8 + +def myModel(x): + + # input patches: 16x16x4 + conv1 = tf.layers.conv2d(inputs=x, filters=16, kernel_size=[5,5], padding="valid", + activation=tf.nn.relu) # out size: 12x12x16 + pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) # out: 6x6x16 + conv2 = tf.layers.conv2d(inputs=pool1, filters=16, kernel_size=[3,3], padding="valid", + activation=tf.nn.relu) # out size: 4x4x16 + pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) # out: 2x2x16 + conv3 = tf.layers.conv2d(inputs=pool2, filters=32, kernel_size=[2,2], padding="valid", + activation=tf.nn.relu) # out size: 1x1x32 + + # Features + features = tf.reshape(conv3, shape=[-1, 32], name="features") + + # 8 neurons for 8 classes + estimated = tf.layers.dense(inputs=features, units=nclasses, activation=None) + estimated_label = tf.argmax(estimated, 1, name="prediction") + + return estimated, estimated_label + +""" Main """ +if len(sys.argv) != 2: + print("Usage : <output directory for SavedModel>") + sys.exit(1) + +# Create the TensorFlow graph +with tf.Graph().as_default(): + + # Placeholders + x = tf.placeholder(tf.float32, [None, None, None, 4], name="x") + y = tf.placeholder(tf.int32 , [None, None, None, 1], name="y") + lr = tf.placeholder_with_default(tf.constant(0.0002, dtype=tf.float32, shape=[]), + shape=[], name="lr") + + # Output + y_estimated, y_label = myModel(x) + + # Loss function + cost = tf.losses.sparse_softmax_cross_entropy(labels=tf.reshape(y, [-1, 1]), + logits=tf.reshape(y_estimated, [-1, nclasses])) + + # Optimizer + optimizer = tf.train.AdamOptimizer(learning_rate=lr, name="optimizer").minimize(cost) + + # Initializer, saver, session + init = tf.global_variables_initializer() + saver = tf.train.Saver( max_to_keep=20 ) + sess = tf.Session() + sess.run(init) + + # Create a SavedModel + CreateSavedModel(sess, ["x:0", "y:0"], ["features:0", "prediction:0"], sys.argv[1]) diff --git a/python/create_savedmodel_simple_fcn.py b/python/create_savedmodel_simple_fcn.py new file mode 100755 index 0000000000000000000000000000000000000000..53f38502ff0acaa2cd396ad174bb1ed83f334992 --- /dev/null +++ b/python/create_savedmodel_simple_fcn.py @@ -0,0 +1,59 @@ +from tricks import * +import sys +import os + +nclasses=8 + +def myModel(x): + + # input patches: 16x16x4 + conv1 = tf.layers.conv2d(inputs=x, filters=16, kernel_size=[5,5], padding="valid", + activation=tf.nn.relu) # out size: 12x12x16 + conv2 = tf.layers.conv2d(inputs=conv1, filters=16, kernel_size=[5,5], padding="valid", + activation=tf.nn.relu) # out size: 8x8x16 + conv3 = tf.layers.conv2d(inputs=conv2, filters=32, kernel_size=[5,5], padding="valid", + activation=tf.nn.relu) # out size: 4x4x32 + conv4 = tf.layers.conv2d(inputs=conv3, filters=32, kernel_size=[4,4], padding="valid", + activation=tf.nn.relu) # out size: 1x1x32 + + # Features + features = tf.reshape(conv4, shape=[-1, 32], name="features") + + # 8 neurons for 8 classes + estimated = tf.layers.dense(inputs=features, units=nclasses, activation=None) + estimated_label = tf.argmax(estimated, 1, name="prediction") + + return estimated, estimated_label + +""" Main """ +if len(sys.argv) != 2: + print("Usage : <output directory for SavedModel>") + sys.exit(1) + +# Create the TensorFlow graph +with tf.Graph().as_default(): + + # Placeholders + x = tf.placeholder(tf.float32, [None, None, None, 4], name="x") + y = tf.placeholder(tf.int32 , [None, None, None, 1], name="y") + lr = tf.placeholder_with_default(tf.constant(0.0002, dtype=tf.float32, shape=[]), + shape=[], name="lr") + + # Output + y_estimated, y_label = myModel(x) + + # Loss function + cost = tf.losses.sparse_softmax_cross_entropy(labels=tf.reshape(y, [-1, 1]), + logits=tf.reshape(y_estimated, [-1, nclasses])) + + # Optimizer + optimizer = tf.train.AdamOptimizer(learning_rate=lr, name="optimizer").minimize(cost) + + # Initializer, saver, session + init = tf.global_variables_initializer() + saver = tf.train.Saver( max_to_keep=20 ) + sess = tf.Session() + sess.run(init) + + # Create a SavedModel + CreateSavedModel(sess, ["x:0", "y:0"], ["features:0", "prediction:0"], sys.argv[1])