// Forked from Cresson Remi / otbtf (commit 4bb2789b, authored by Remi Cresson);
// 1973 commits behind the upstream repository.
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbTensorflowGraphOperations.h"
namespace otb {
namespace tf {
//
// Restore a model from a path
//
void RestoreModel(const std::string path, tensorflow::SavedModelBundle & bundle)
{
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = path;
std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict =
{{bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
auto status = bundle.session->Run(feed_dict, {}, {bundle.meta_graph_def.saver_def().restore_op_name()}, nullptr);
if (!status.ok())
{
itkGenericExceptionMacro("Can't restore the input model: " << status.ToString() );
}
}
//
// Save the variables of a model to a checkpoint.
// The checkpoint path is fed, as a scalar string tensor, to the saver's
// filename tensor; the saver's save tensor is then fetched, which runs the
// save operation in the bundle's session.
// Raises an ITK generic exception (itkGenericExceptionMacro) when the
// session run does not succeed.
//
void SaveModel(const std::string path, tensorflow::SavedModelBundle & bundle)
{
  const auto & saverDef = bundle.meta_graph_def.saver_def();

  // Scalar string tensor holding the checkpoint path
  tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
  checkpointPathTensor.scalar<std::string>()() = path;
  std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict =
  {{saverDef.filename_tensor_name(), checkpointPathTensor}};

  // SaverDef::save_tensor_name is a *tensor* name ("node:index"), unlike
  // restore_op_name which is an op name. It must therefore be requested as an
  // output tensor, not passed as a target node name (which expects plain node
  // names). This mirrors how the Python Saver.save runs it.
  const std::string & saveTensorName = saverDef.save_tensor_name();
  std::vector<tensorflow::Tensor> outputs;
  auto status = bundle.session->Run(feed_dict, {saveTensorName}, {}, &outputs);
  if (!status.ok())
  {
    itkGenericExceptionMacro("Can't save the model: " << status.ToString() );
  }
}
//
// Load a session and a graph from a folder
//
void LoadModel(const std::string path, tensorflow::SavedModelBundle & bundle)
{
tensorflow::RunOptions runoptions;
runoptions.set_trace_level(tensorflow::RunOptions_TraceLevel_FULL_TRACE);
auto status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(), runoptions,
path, {tensorflow::kSavedModelTagServe}, &bundle);
if (!status.ok())
{
itkGenericExceptionMacro("Can't load the input model: " << status.ToString() );
}
}
//
// Load a graph from a .meta file.
// Reads `filename` as a binary-serialized MetaGraphDef and returns a copy of
// the GraphDef it contains.
// Raises an ITK generic exception (itkGenericExceptionMacro) when
// ReadBinaryProto reports a failure.
//
tensorflow::GraphDef LoadGraph(std::string filename)
{
  tensorflow::MetaGraphDef meta_graph_def;
  auto status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &meta_graph_def);
  if (!status.ok())
  {
    itkGenericExceptionMacro("Can't load the input model: " << status.ToString() );
  }
  return meta_graph_def.graph_def();
}
//
// Get the following attributes of the specified tensors (by name) of a graph:
// - shape
// - datatype
// Here we assume that the node's output is a tensor: the reported shape is
// the first entry of the node's "_output_shapes" attribute.
//
// The data type is read from the node attributes "T", then "dtype", then
// "output_type"; when several are present, the later one wins.
//
// On return, shapes[k] and dataTypes[k] describe tensorsNames[k].
// Raises an ITK generic exception (itkGenericExceptionMacro) when a name
// matches no node, when none of the data type attributes is present, or when
// "_output_shapes" is missing or empty.
//
void GetTensorAttributes(const tensorflow::GraphDef & graph, std::vector<std::string> & tensorsNames,
    std::vector<tensorflow::TensorShapeProto> & shapes, std::vector<tensorflow::DataType> & dataTypes)
{
  // Allocation
  shapes.clear();
  shapes.reserve(tensorsNames.size());
  dataTypes.clear();
  dataTypes.reserve(tensorsNames.size());

  // Get infos
  for (const std::string & name : tensorsNames)
  {
    bool found = false;
    for (int i = 0 ; i < graph.node_size() ; i++)
    {
      // Bind by reference: copying a NodeDef copies all of its attributes
      const tensorflow::NodeDef & node = graph.node(i);
      if (node.name() != name)
      {
        continue;
      }
      found = true;
      const auto & attrs = node.attr();

      // Start from an explicit invalid type so that a node carrying none of
      // the expected attributes is reported, instead of pushing an
      // uninitialized value.
      tensorflow::DataType ts_dt = tensorflow::DT_INVALID;

      // Default (input?) tensor type
      auto attrIt = attrs.find("T");
      if (attrIt != attrs.end())
      {
        ts_dt = attrIt->second.type();
      }
      attrIt = attrs.find("dtype");
      if (attrIt != attrs.end())
      {
        ts_dt = attrIt->second.type();
      }
      attrIt = attrs.find("output_type");
      if (attrIt != attrs.end())
      {
        // if there is an output type, we take it instead of the
        // datatype of the input tensor
        ts_dt = attrIt->second.type();
      }
      if (ts_dt == tensorflow::DT_INVALID)
      {
        itkGenericExceptionMacro("Can't determine the data type of tensor \"" << name << "\"" );
      }
      dataTypes.push_back(ts_dt);

      // Get the tensor's shape.
      // protobuf Map::at() fails a CHECK (aborting the process) on a missing
      // key, and shape(0) on an empty list is out of range: test both and
      // raise a catchable error instead.
      auto shapesIt = attrs.find("_output_shapes");
      if (shapesIt == attrs.end() || shapesIt->second.list().shape_size() == 0)
      {
        itkGenericExceptionMacro("Tensor \"" << name << "\" has no \"_output_shapes\" attribute" );
      }
      shapes.push_back(shapesIt->second.list().shape(0));

      // Node names are unique within a GraphDef: stop at the first match
      break;
    }
    if (!found)
    {
      itkGenericExceptionMacro("Tensor name \"" << name << "\" not found" );
    }
  }
}
//
// Debug helper that writes to std::cout.
// Lists every node of the graph (index and name). For each node whose name
// appears in nodesNames, it also dumps the name, the input count, the node's
// debug string, and every attribute (key, value case and debug string).
// A name listed several times in nodesNames is dumped once per occurrence.
//
void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector<std::string> & nodesNames)
{
  std::cout << "Go through graph:" << std::endl;
  std::cout << "#\tname" << std::endl;
  const int nodeCount = graph.node_size();
  for (int idx = 0 ; idx < nodeCount ; ++idx)
  {
    const tensorflow::NodeDef & current = graph.node(idx);
    std::cout << idx << "\t" << current.name() << std::endl;
    for (const std::string & wanted : nodesNames)
    {
      if (current.name() != wanted)
      {
        continue;
      }
      std::cout << "Node " << idx << " : " << std::endl;
      std::cout << "\tName: " << current.name() << std::endl;
      std::cout << "\tinput_size() : " << current.input_size() << std::endl;
      std::cout << "\tPrintDebugString --------------------------------" << std::endl;
      current.PrintDebugString();
      std::cout << "\t-------------------------------------------------" << std::endl;

      // Dump every attribute of the matching node
      std::cout << "\tAttributes of the node: " << std::endl;
      for (const auto & keyValue : current.attr())
      {
        std::cout << "\t\tKey :" << keyValue.first << std::endl;
        std::cout << "\t\tValue.value_case() :" << keyValue.second.value_case() << std::endl;
        std::cout << "\t\tPrintDebugString --------------------------------" << std::endl;
        keyValue.second.PrintDebugString();
        std::cout << "\t\t-------------------------------------------------" << std::endl;
        std::cout << std::endl;
      }
    }
  }
}
} // end namespace tf
} // end namespace otb