Commit 02c1eabc authored by Narcon Nicolas's avatar Narcon Nicolas
Browse files

COMP: some fixes

Showing with 7 additions and 6 deletions
+7 -6
...@@ -253,7 +253,7 @@ public: ...@@ -253,7 +253,7 @@ public:
// Setup filter // Setup filter
m_TFFilter = TFModelFilterType::New(); m_TFFilter = TFModelFilterType::New();
m_TFFilter->SetSavedModel(*m_SavedModel); m_TFFilter->SetSavedModel(m_SavedModel);
m_TFFilter->SetOutputTensors(GetParameterStringList("output.names")); m_TFFilter->SetOutputTensors(GetParameterStringList("output.names"));
m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale")); m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale"));
otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale()); otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale());
......
...@@ -82,7 +82,7 @@ tensorflow::GraphDef LoadGraph(std::string filename) ...@@ -82,7 +82,7 @@ tensorflow::GraphDef LoadGraph(std::string filename)
// Get the following attributes of the specified tensors (by name) of a graph: // Get the following attributes of the specified tensors (by name) of a graph:
// - shape // - shape
// - datatype // - datatype
void GetTensorAttributes(const map<string, TensorInfo> layers, std::vector<std::string> & tensorsNames, void GetTensorAttributes(const std::map<std::string, tensorflow::TensorInfo> layers, std::vector<std::string> & tensorsNames,
std::vector<tensorflow::TensorShapeProto> & shapes, std::vector<tensorflow::DataType> & dataTypes) std::vector<tensorflow::TensorShapeProto> & shapes, std::vector<tensorflow::DataType> & dataTypes)
{ {
// Allocation // Allocation
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
// Tensorflow // Tensorflow
#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
// Tensorflow helpers // Tensorflow helpers
#include "otbTensorflowGraphOperations.h" #include "otbTensorflowGraphOperations.h"
...@@ -103,7 +104,7 @@ public: ...@@ -103,7 +104,7 @@ public:
typedef std::vector<tensorflow::Tensor> TensorListType; typedef std::vector<tensorflow::Tensor> TensorListType;
/** Set and Get the Tensorflow session and graph */ /** Set and Get the Tensorflow session and graph */
void SetSaveModel(tensorflow::SavedModelBundle * saved_model;) { m_SavedModel = saved_model; } void SetSavedModel(tensorflow::SavedModelBundle * saved_model) { m_SavedModel = saved_model; }
tensorflow::SavedModelBundle * GetSavedModel() { return m_SavedModel; } tensorflow::SavedModelBundle * GetSavedModel() { return m_SavedModel; }
void SearchAndSetSignatureDef(const tensorflow::protobuf::Map<std::string, tensorflow::SignatureDef> signatures) void SearchAndSetSignatureDef(const tensorflow::protobuf::Map<std::string, tensorflow::SignatureDef> signatures)
......
...@@ -173,7 +173,7 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> ...@@ -173,7 +173,7 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
// Set all subelement of the model // Set all subelement of the model
auto signatures = this->GetSavedModel()->GetSignatures(); auto signatures = this->GetSavedModel()->GetSignatures();
signaturedef = this->SearchAndSetSignatureDef(signatures); auto signaturedef = this->SearchAndSetSignatureDef(signatures);
for (auto& output: signaturedef.outputs()) for (auto& output: signaturedef.outputs())
{ {
std::string userName = output.first.substr(0, output.first.find(":")); std::string userName = output.first.substr(0, output.first.find(":"));
...@@ -192,8 +192,8 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> ...@@ -192,8 +192,8 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
} }
// Get input and output tensors datatypes and shapes // Get input and output tensors datatypes and shapes
tf::GetInputAttributes(this->GetSignatureDef(), m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes); tf::GetTensorAttributes(signaturedef.inputs(), m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes);
tf::GetOutputAttributes(this->GetSignatureDef(), m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes); tf::GetTensorAttributes(signaturedef.outputs(), m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
} }
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment