Commit 592e878f authored by Narcon Nicolas's avatar Narcon Nicolas
Browse files

WIP: fix (?) default value for model.tagsets

parent 98193d5f
2 merge requests!11Apply clang-format to cpp files,!3Migration to TensorFlow2
Showing with 11 additions and 4 deletions
+11 -4
...@@ -177,7 +177,7 @@ public: ...@@ -177,7 +177,7 @@ public:
AddParameter(ParameterType_Bool, "model.fullyconv", "Fully convolutional"); AddParameter(ParameterType_Bool, "model.fullyconv", "Fully convolutional");
MandatoryOff ("model.fullyconv"); MandatoryOff ("model.fullyconv");
AddParameter(ParameterType_StringList, "model.tagsets", "Which tags (i.e. v1.MetaGraphDefs) to load from the saved model. Can be retrieved by running `saved_model_cli show --dir your_model_dir --all`"); AddParameter(ParameterType_StringList, "model.tagsets", "Which tags (i.e. v1.MetaGraphDefs) to load from the saved model. Can be retrieved by running `saved_model_cli show --dir your_model_dir --all`");
SetDefaultParameterStringList ("model.tagsets", {tensorflow::kSavedModelTagServe}) MandatoryOff ("model.tagsets");
// Output tensors parameters // Output tensors parameters
AddParameter(ParameterType_Group, "output", "Output tensors parameters"); AddParameter(ParameterType_Group, "output", "Output tensors parameters");
...@@ -249,8 +249,15 @@ public: ...@@ -249,8 +249,15 @@ public:
{ {
// Load the Tensorflow bundle // Load the Tensorflow bundle
std::vector<std::string> tagList = GetParameterStringList("model.tagsets"); if (HasUserValue("model.tagsets")){
std::unordered_set<std::string> tagSets(tagList.begin(), tagList.end()); std::vector<std::string> tagList = GetParameterStringList("model.tagsets");
std::unordered_set<std::string> tagSets(tagList.begin(), tagList.end()); // convert to unordered_set
}else{
std::unordered_set<std::string> tagSets = {tensorflow::kSavedModelTagServe};
}
tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel, tagSets); tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel, tagSets);
// Prepare inputs // Prepare inputs
......
...@@ -49,7 +49,7 @@ void SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bu ...@@ -49,7 +49,7 @@ void SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bu
// //
// Load a session and a graph from a folder // Load a session and a graph from a folder
// //
void LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle, std::list<std::string> tagsets) void LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle, std::unordered_set<std::string> tagsets)
{ {
tensorflow::RunOptions runoptions; tensorflow::RunOptions runoptions;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment