Commit fa0bf983 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

dr application working (monoband output) for autoencoders and tiedautoencoders

No related merge requests found
Showing with 33 additions and 19 deletions
+33 -19
...@@ -194,7 +194,7 @@ private: ...@@ -194,7 +194,7 @@ private:
otbAppLogINFO("Loading model"); otbAppLogINFO("Loading model");
m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"),
MachineLearningModelFactoryType::ReadMode); MachineLearningModelFactoryType::ReadMode);
otbAppLogINFO("yo");
if (m_Model.IsNull()) if (m_Model.IsNull())
{ {
otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type"); otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
......
...@@ -2,18 +2,20 @@ ...@@ -2,18 +2,20 @@
#define AutoencoderModelFactory_h #define AutoencoderModelFactory_h
#include <shark/Models/TiedAutoencoder.h>
#include <shark/Models/Autoencoder.h>
#include "itkObjectFactoryBase.h" #include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h" #include "itkImageIOBase.h"
namespace otb namespace otb
{ {
template <class TInputValue, class TTargetValue> template <class TInputValue, class TTargetValue, class AutoencoderType>
class ITK_EXPORT AutoencoderModelFactory : public itk::ObjectFactoryBase class ITK_EXPORT AutoencoderModelFactoryBase : public itk::ObjectFactoryBase
{ {
public: public:
/** Standard class typedefs. */ /** Standard class typedefs. */
typedef AutoencoderModelFactory Self; typedef AutoencoderModelFactoryBase Self;
typedef itk::ObjectFactoryBase Superclass; typedef itk::ObjectFactoryBase Superclass;
typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer; typedef itk::SmartPointer<const Self> ConstPointer;
...@@ -26,26 +28,39 @@ public: ...@@ -26,26 +28,39 @@ public:
itkFactorylessNewMacro(Self); itkFactorylessNewMacro(Self);
/** Run-time type information (and related methods). */ /** Run-time type information (and related methods). */
itkTypeMacro(AutoencoderModelFactory, itk::ObjectFactoryBase); itkTypeMacro(AutoencoderModelFactoryBase, itk::ObjectFactoryBase);
/** Register one factory of this type */ /** Register one factory of this type */
static void RegisterOneFactory(void) static void RegisterOneFactory(void)
{ {
Pointer AEFactory = AutoencoderModelFactory::New(); Pointer AEFactory = AutoencoderModelFactoryBase::New();
itk::ObjectFactoryBase::RegisterFactory(AEFactory); itk::ObjectFactoryBase::RegisterFactory(AEFactory);
} }
protected: protected:
AutoencoderModelFactory(); AutoencoderModelFactoryBase();
~AutoencoderModelFactory() ITK_OVERRIDE; ~AutoencoderModelFactoryBase() ITK_OVERRIDE;
private: private:
AutoencoderModelFactory(const Self &); //purposely not implemented AutoencoderModelFactoryBase(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented void operator =(const Self&); //purposely not implemented
}; };
template <class TInputValue, class TTargetValue>
class ITK_EXPORT AutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> {};
template <class TInputValue, class TTargetValue>
class ITK_EXPORT TiedAutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> {};
} //namespace otb } //namespace otb
#ifndef OTB_MANUAL_INSTANTIATION #ifndef OTB_MANUAL_INSTANTIATION
#include "AutoencoderModelFactory.txx" #include "AutoencoderModelFactory.txx"
#endif #endif
......
...@@ -25,11 +25,10 @@ ...@@ -25,11 +25,10 @@
#include "AutoencoderModel.h" #include "AutoencoderModel.h"
#include "itkVersion.h" #include "itkVersion.h"
#include <shark/Models/Autoencoder.h>//normal autoencoder model
namespace otb namespace otb
{ {
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue, class AutoencoderType>
AutoencoderModelFactory<TInputValue,TOutputValue>::AutoencoderModelFactory() AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::AutoencoderModelFactoryBase()
{ {
std::string classOverride = std::string("otbMachineLearningModel"); std::string classOverride = std::string("otbMachineLearningModel");
...@@ -40,22 +39,22 @@ AutoencoderModelFactory<TInputValue,TOutputValue>::AutoencoderModelFactory() ...@@ -40,22 +39,22 @@ AutoencoderModelFactory<TInputValue,TOutputValue>::AutoencoderModelFactory()
"Shark RF ML Model", "Shark RF ML Model",
1, 1,
// itk::CreateObjectFunction<AutoencoderModel<TInputValue,TOutputValue> >::New()); // itk::CreateObjectFunction<AutoencoderModel<TInputValue,TOutputValue> >::New());
itk::CreateObjectFunction<AutoencoderModel<TInputValue,shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron> > >::New()); itk::CreateObjectFunction<AutoencoderModel<TInputValue,AutoencoderType > >::New());
} }
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue, class AutoencoderType>
AutoencoderModelFactory<TInputValue,TOutputValue>::~AutoencoderModelFactory() AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::~AutoencoderModelFactoryBase()
{ {
} }
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue, class AutoencoderType>
const char* AutoencoderModelFactory<TInputValue,TOutputValue>::GetITKSourceVersion(void) const const char* AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::GetITKSourceVersion(void) const
{ {
return ITK_SOURCE_VERSION; return ITK_SOURCE_VERSION;
} }
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue, class AutoencoderType>
const char* AutoencoderModelFactory<TInputValue,TOutputValue>::GetDescription() const const char* AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::GetDescription() const
{ {
return "Autoencoder model factory"; return "Autoencoder model factory";
} }
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment