Commit fa0bf983 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

dr application working (monoband output) for autoencoders and tiedautoencoders

No related merge requests found
Showing with 33 additions and 19 deletions
+33 -19
......@@ -194,7 +194,7 @@ private:
otbAppLogINFO("Loading model");
m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"),
MachineLearningModelFactoryType::ReadMode);
otbAppLogINFO("yo");
if (m_Model.IsNull())
{
otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
......
......@@ -2,18 +2,20 @@
#define AutoencoderModelFactory_h
#include <shark/Models/TiedAutoencoder.h>
#include <shark/Models/Autoencoder.h>
#include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h"
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT AutoencoderModelFactory : public itk::ObjectFactoryBase
template <class TInputValue, class TTargetValue, class AutoencoderType>
class ITK_EXPORT AutoencoderModelFactoryBase : public itk::ObjectFactoryBase
{
public:
/** Standard class typedefs. */
typedef AutoencoderModelFactory Self;
typedef AutoencoderModelFactoryBase Self;
typedef itk::ObjectFactoryBase Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
......@@ -26,26 +28,39 @@ public:
itkFactorylessNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(AutoencoderModelFactory, itk::ObjectFactoryBase);
itkTypeMacro(AutoencoderModelFactoryBase, itk::ObjectFactoryBase);
/** Register one factory of this type */
static void RegisterOneFactory(void)
{
Pointer AEFactory = AutoencoderModelFactory::New();
Pointer AEFactory = AutoencoderModelFactoryBase::New();
itk::ObjectFactoryBase::RegisterFactory(AEFactory);
}
protected:
AutoencoderModelFactory();
~AutoencoderModelFactory() ITK_OVERRIDE;
AutoencoderModelFactoryBase();
~AutoencoderModelFactoryBase() ITK_OVERRIDE;
private:
AutoencoderModelFactory(const Self &); //purposely not implemented
AutoencoderModelFactoryBase(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
};
template <class TInputValue, class TTargetValue>
class ITK_EXPORT AutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> {};
template <class TInputValue, class TTargetValue>
class ITK_EXPORT TiedAutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> {};
} //namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "AutoencoderModelFactory.txx"
#endif
......
......@@ -25,11 +25,10 @@
#include "AutoencoderModel.h"
#include "itkVersion.h"
#include <shark/Models/Autoencoder.h>//normal autoencoder model
namespace otb
{
template <class TInputValue, class TOutputValue>
AutoencoderModelFactory<TInputValue,TOutputValue>::AutoencoderModelFactory()
template <class TInputValue, class TOutputValue, class AutoencoderType>
AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::AutoencoderModelFactoryBase()
{
std::string classOverride = std::string("otbMachineLearningModel");
......@@ -40,22 +39,22 @@ AutoencoderModelFactory<TInputValue,TOutputValue>::AutoencoderModelFactory()
"Shark RF ML Model",
1,
// itk::CreateObjectFunction<AutoencoderModel<TInputValue,TOutputValue> >::New());
itk::CreateObjectFunction<AutoencoderModel<TInputValue,shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron> > >::New());
itk::CreateObjectFunction<AutoencoderModel<TInputValue,AutoencoderType > >::New());
}
template <class TInputValue, class TOutputValue>
AutoencoderModelFactory<TInputValue,TOutputValue>::~AutoencoderModelFactory()
template <class TInputValue, class TOutputValue, class AutoencoderType>
AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::~AutoencoderModelFactoryBase()
{
}
template <class TInputValue, class TOutputValue>
const char* AutoencoderModelFactory<TInputValue,TOutputValue>::GetITKSourceVersion(void) const
template <class TInputValue, class TOutputValue, class AutoencoderType>
const char* AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::GetITKSourceVersion(void) const
{
return ITK_SOURCE_VERSION;
}
template <class TInputValue, class TOutputValue>
const char* AutoencoderModelFactory<TInputValue,TOutputValue>::GetDescription() const
template <class TInputValue, class TOutputValue, class AutoencoderType>
const char* AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::GetDescription() const
{
return "Autoencoder model factory";
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment