#include "lsgrmController.h"

namespace lsgrm
{

/*
 * Default constructor: automatic tiling mode, every size/count reset to
 * zero, temporary-file cleanup enabled and a threshold of 75.
 */
template<class TSegmenter>
Controller<TSegmenter>::Controller()
{
  // Tiling strategy and layout
  this->m_TilingMode = LSGRM_TILING_AUTO;
  this->m_NbTilesX = 0;
  this->m_NbTilesY = 0;
  this->m_TileWidth = 0;
  this->m_TileHeight = 0;
  this->m_Margin = 0;

  // Iteration counts
  this->m_NumberOfIterations = 0;
  this->m_NumberOfFirstIterations = 0;

  // Segmentation threshold and memory budget (0 = not set yet)
  this->m_Threshold = 75;
  this->m_Memory = 0;

  // Housekeeping
  this->m_CleanTemporaryFiles = true;
}

/*
 * Destructor: nothing to release explicitly.
 */
template<class TSegmenter>
Controller<TSegmenter>::~Controller() = default;

/*
 * Run the tiled segmentation.
 *
 * For the AUTO and USER tiling modes, the steps are:
 *  1. configure the tiling (AUTO: GetAutomaticConfiguration(); USER: only the
 *     stability margin is derived here, from m_NumberOfFirstIterations),
 *  2. split the input image into tiles (SplitOTBImage),
 *  3. run a first partial segmentation, then further partial segmentations
 *     for as long as the accumulated memory exceeds m_Memory and the last
 *     pass reported remaining fusions,
 *  4. if the accumulated memory fits in m_Memory, merge all graphs and store
 *     the result in m_LabelImage; otherwise raise an ITK exception.
 *
 * With OTB_USE_MPI defined, every process whose rank is not 0 returns before
 * step 4.
 *
 * For any other tiling mode nothing is done (see the todo at the end), so
 * m_LabelImage is not assigned by this function in that case.
 */
template<class TSegmenter>
void Controller<TSegmenter>::RunSegmentation()
{

  if (m_TilingMode == LSGRM_TILING_AUTO || m_TilingMode == LSGRM_TILING_USER)
    {

    const unsigned int numberOfIterationsForPartialSegmentations = 3; // TODO: find a smart value
    // Iteration budget left for the final merge; decremented after each partial pass below
    unsigned int numberOfIterationsRemaining = m_NumberOfIterations;

    if(m_TilingMode == LSGRM_TILING_AUTO)
      {
      // Sets the tiling layout, tile size, number of first iterations and margin
      this->GetAutomaticConfiguration();
      }
    else if (m_TilingMode == LSGRM_TILING_USER)
      {
      // Margin = 2^(n+1) - 2 where n is the number of first iterations.
      // The tile size and tiling layout members are used as they currently are.
      m_Margin = static_cast<unsigned int>(pow(2, m_NumberOfFirstIterations + 1) - 2);
      }

    std::cout <<
        "--- Configuration: " <<
        "\n\tAvailable RAM: " << m_Memory <<
        "\n\tInput image dimensions: " << m_InputImage->GetLargestPossibleRegion().GetSize() <<
        "\n\tNumber of first iterations: " << m_NumberOfFirstIterations <<
        "\n\tStability margin: " << m_Margin <<
        "\n\tRegular tile size: " << m_TileWidth << " x " << m_TileHeight <<
        "\n\tTiling layout: " << m_NbTilesX << " x " << m_NbTilesY << std::endl;

    // Compute the splitting scheme
    m_Tiles = SplitOTBImage<ImageType>(m_InputImage, m_TileWidth, m_TileHeight, m_Margin,
        m_NbTilesX, m_NbTilesY, m_TemporaryFilesPrefix);

    // Boolean indicating if there are remaining fusions
    // (presumably updated by the segmentation calls below through a reference
    // parameter, since it is re-read after each call -- confirm in their declarations)
    bool isFusion = false;

    // Run first partial segmentation
    boost::timer t; t.restart();
    auto accumulatedMemory = RunFirstPartialSegmentation<TSegmenter>(
        m_InputImage,
        m_SpecificParameters,
        m_Threshold,
        m_NumberOfFirstIterations,
        numberOfIterationsForPartialSegmentations,
        m_Tiles,
        m_NbTilesX,
        m_NbTilesY,
        m_TileWidth,
        m_TileHeight,
        isFusion);

    // Update the given number of iterations
    // NOTE(review): unsigned subtraction -- wraps around to a very large value
    // when m_NumberOfIterations < m_NumberOfFirstIterations (including the
    // constructor default of 0). Confirm whether this is intended.
    numberOfIterationsRemaining -= m_NumberOfFirstIterations;

    // Gathering useful variables
    GatherUsefulVariables(accumulatedMemory, isFusion);

    // Time monitoring
    ShowTime(t);

    std::cout << "accumulatedMemory=" << accumulatedMemory << std::endl;

    // Keep running partial segmentations while the graphs do not fit in the
    // memory budget and the previous pass still reported fusions
    while(accumulatedMemory > m_Memory && isFusion)
      {

      isFusion = false;
      accumulatedMemory = RunPartialSegmentation<TSegmenter>(
          m_SpecificParameters,
          m_Threshold,
          numberOfIterationsForPartialSegmentations,
          m_Tiles,
          m_NbTilesX,
          m_NbTilesY,
          m_InputImage->GetLargestPossibleRegion().GetSize()[0],
          m_InputImage->GetLargestPossibleRegion().GetSize()[1],
          m_InputImage->GetNumberOfComponentsPerPixel(),
          isFusion);

      // Update the given number of iterations
      // NOTE(review): same unsigned wrap-around as above once fewer than
      // numberOfIterationsForPartialSegmentations iterations remain; the loop
      // condition does not look at the remaining iteration count.
      numberOfIterationsRemaining -= numberOfIterationsForPartialSegmentations;

      // Gathering useful variables
      GatherUsefulVariables(accumulatedMemory, isFusion);

      // Time monitoring
      ShowTime(t);
      }


#ifdef OTB_USE_MPI
    // Only the process of rank 0 goes on to the final merge
    if (otb::MPIConfig::Instance()->GetMyRank() != 0)
      return;
#endif

    if(accumulatedMemory <= m_Memory)
      {
      // Merge all the graphs
      m_LabelImage = MergeAllGraphsAndAchieveSegmentation<TSegmenter>(
          m_SpecificParameters,
          m_Threshold,
          m_Tiles,
          m_NbTilesX,
          m_NbTilesY,
          m_InputImage->GetLargestPossibleRegion().GetSize()[0],
          m_InputImage->GetLargestPossibleRegion().GetSize()[1],
          m_InputImage->GetNumberOfComponentsPerPixel(),
          numberOfIterationsRemaining);

      ShowTime(t);

      }
    else // accumulatedMemory > m_Memory
      {
      // That means there are no more possible fusions but we can not store the output graph
      // Todo do not clean up temporary directory before copying resulting graph to the output directory
      // In the output directory add an info file to give the number of tiles.
      itkExceptionMacro(<< "No more possible fusions, but can not store the output graph");
      }
    }
  else // tiling_mode is none
    {
    // todo use classic grm
    // Currently a no-op: no segmentation is run and m_LabelImage is left untouched.
    }


}

/*
 * Estimate the memory occupied by one graph node.
 *
 * A square probe image (100 x 100 pixels, same number of components per
 * pixel as the input image) is allocated, its nodes are initialized in a
 * temporary segmenter, and the graph memory reported by that segmenter is
 * divided by the number of pixels of the probe image.
 *
 * The probe image is only allocated; its pixel values are not set here.
 */
template<class TSegmenter>
unsigned int Controller<TSegmenter>::GetNodeMemory()
{
  // Side length (in pixels) of the square probe image
  const unsigned int side = 100;

  // Region covering side x side pixels, starting at index 0
  typename ImageType::IndexType origin;
  origin.Fill(0);
  typename ImageType::SizeType extent;
  extent.Fill(side);
  typename ImageType::RegionType probeRegion(origin, extent);

  // Probe image with the same band count as the input image
  typename ImageType::Pointer probeImage = ImageType::New();
  probeImage->SetRegions(probeRegion);
  probeImage->SetNumberOfComponentsPerPixel(m_InputImage->GetNumberOfComponentsPerPixel());
  probeImage->Allocate();

  // Build the initial graph of the probe image (FOUR: presumably 4-connectivity)
  TSegmenter probeSegmenter;
  probeSegmenter.SetInput(probeImage);
  lsrm::GraphOperations<TSegmenter>::InitNodes(probeImage, probeSegmenter, FOUR);

  // Average graph memory per pixel of the probe image
  const unsigned int memoryPerNode = probeSegmenter.GetGraphMemory() / (side*side);
  itkDebugMacro(<<"Size of a node is " << memoryPerNode);
  return memoryPerNode;
}

/*
 * Compute the maximum number of nodes which can fit in the memory budget.
 *
 * Memory budget:
 *  - if m_Memory is 0 (its value after construction), the value returned by
 *    getMemorySize() is used;
 *  - otherwise the value already stored in m_Memory (e.g. set through
 *    SetInternalMemoryAvailable()) is kept instead of being overwritten.
 *
 * Side effect: m_Memory is halved as a safety margin and the halved value
 * stays stored in the member (RunSegmentation() compares the accumulated
 * graph memory against it). Calling this function again on the same object
 * therefore halves the stored budget again.
 *
 * Returns ceil(m_Memory / GetNodeMemory()), computed after the halving.
 *
 * Raises an ITK exception when no memory size could be obtained or when the
 * measured node size is zero (the division would be meaningless).
 */
template<class TSegmenter>
long unsigned int Controller<TSegmenter>::GetMaximumNumberOfNodesInMemory()
{
  itkDebugMacro(<< "Computing maximum number of nodes in memory");

  // Only query the system when no budget has been provided
  if (m_Memory == 0)
    {
    m_Memory = getMemorySize();
    if (m_Memory == 0)
      {
      // Checked at run time: an assert would vanish in release builds
      itkExceptionMacro(<< "Unable to determine the amount of available memory");
      }
    }

  m_Memory /= 2; // For safety and can prevent out of memory troubles

  const unsigned int nodeMemory = GetNodeMemory();
  if (nodeMemory == 0)
    {
    itkExceptionMacro(<< "The estimated memory of a node is zero");
    }

  // double keeps byte counts exact far beyond the 24-bit mantissa of float
  return static_cast<long unsigned int>(
      std::ceil(static_cast<double>(m_Memory) / static_cast<double>(nodeMemory)));

}

/*
 * Compute the largest stability margin usable for a tile of the given size.
 *
 * The margin associated with k iterations is 2^(k+1) - 2. The function keeps
 * the largest such margin that is strictly lower than min(width, height)/2.
 *
 * width, height : tile dimensions
 * niter         : (output) number of iterations
 * margin        : (output) retained margin
 *
 * When min(width, height)/2 is not greater than 2, no candidate is accepted:
 * niter is set to 0 while margin keeps its initial value of 2.
 */
template<class TSegmenter>
void
Controller<TSegmenter>::ComputeMaximumStabilityMargin(unsigned int width,
    unsigned int height, unsigned int &niter, unsigned int &margin)
{
  itkDebugMacro(<< "Computing maximum stability margin");

  // The naive strategy consider a margin value and a stable size equal.
  const unsigned int marginLimit = std::min(width, height)/2;

  unsigned int iterations = 1;
  unsigned int candidate = 2; // 2^(1+1) - 2
  unsigned int retained = candidate;

  // Going from k to k+1 iterations turns a margin m = 2^(k+1) - 2
  // into 2^(k+2) - 2 = 2*m + 2
  for (; candidate < marginLimit; candidate = 2*candidate + 2)
    {
    retained = candidate;
    ++iterations;
    }

  // The last counted iteration is the one whose margin was rejected
  niter = iterations - 1;
  margin = retained;

  itkDebugMacro(<< "Number of iterations=" << niter << " margin=" << margin);

}

/*
 * Compute a tiling layout which minimizes a criterion based on tile compactness
 * and memory usage
 *
 * On return the following members are set: m_NbTilesX, m_NbTilesY,
 * m_TileWidth, m_TileHeight, m_NumberOfFirstIterations and m_Margin.
 * m_Memory is also modified, through GetMaximumNumberOfNodesInMemory().
 *
 * If no candidate layout is evaluated, the default 1x1 layout is kept.
 *
 * Raises an ITK exception when the memory budget cannot hold a single node.
 *
 * TODO: use the lsgrmSplitter to truly compute the largest tile of a given layout
 */
template<class TSegmenter>
void Controller<TSegmenter>::GetAutomaticConfiguration()
{

  itkDebugMacro(<<"Get automatic configuration");

  // Compute the maximum number of nodes that can fit the memory
  unsigned long int maximumNumberOfNodesInMemory = GetMaximumNumberOfNodesInMemory();
  itkDebugMacro(<<"Maximum number of nodes in memory is " << maximumNumberOfNodesInMemory);
  if (maximumNumberOfNodesInMemory == 0)
    {
    // The ratios computed below would otherwise divide by zero
    itkExceptionMacro(<< "The available memory can not hold a single node");
    }

  // Number of nodes in the entire image.
  // The product is computed on 64 bits: width*height evaluated on
  // unsigned int wraps around for images larger than 2^32 pixels.
  const unsigned int imageWidth = m_InputImage->GetLargestPossibleRegion().GetSize()[0];
  const unsigned int imageHeight = m_InputImage->GetLargestPossibleRegion().GetSize()[1];
  const unsigned long long int nbOfNodesInImage =
      static_cast<unsigned long long int>(imageWidth) * imageHeight;

  // Default layout: 1x1
  m_NbTilesX = 1;
  m_NbTilesY = 1;

  // Without margins, the number of tiles maximizing memory use
  // is equal to: nbOfNodesInImage / maximumNumberOfNodesInMemory.
  // Actually, there are tile margins. And the best scenario is to have
  // square tiles with margin = width/2, that is tiles 4x larger.
  // Hence the number of tiles maximizing memory use is 4x larger.
  // (double is used because float can not represent large node counts exactly)
  unsigned int minimumNumberOfTiles = static_cast<unsigned int>(std::ceil(
      4.0 * static_cast<double>(nbOfNodesInImage) / static_cast<double>(maximumNumberOfNodesInMemory)));
  itkDebugMacro(<<"Minimum number of tiles is " << minimumNumberOfTiles);

  // In the following steps, we will optimize tiling layout, starting from a number
  // of tiles equal to "minimumNumberOfTiles", up to a number of tiles equal to
  // four times "minimumNumberOfTiles"
  unsigned int maximumNumberOfTiles = minimumNumberOfTiles * 4;

  // Search for layout which minimizes the criterion
  // The criterion is the ratio between compactness and memory usage
  // (i.e. tileWidth * tileHeight / maximumNumberOfNodesInMemory)
  itkDebugMacro(<<"Computing layouts properties:");
  float lowestCriterionValue = itk::NumericTraits<float>::max();
  for (unsigned int nbOfTiles = minimumNumberOfTiles ; nbOfTiles <= maximumNumberOfTiles ; nbOfTiles++)
    {
#ifdef OTB_USE_MPI
    // We want a number of tiles which is a multiple of the number of MPI processes.
    // This only depends on nbOfTiles, hence it is tested once per tile count.
    if (nbOfTiles % otb::MPIConfig::Instance()->GetNbProcs() != 0)
      {
      continue;
      }
#endif

    // Each divisor of nbOfTiles gives a (rows x columns) layout.
    // For each one, compute the criterion of the tiling
    for (unsigned int layoutNCol = 1; layoutNCol<=nbOfTiles; layoutNCol++)
      {
      if (nbOfTiles % layoutNCol != 0) // Not a divisor of the number of tiles
        {
        continue;
        }

      // Tiling layout
      unsigned int layoutNRow = nbOfTiles / layoutNCol;
      unsigned int tileWidth = imageWidth / layoutNCol;
      unsigned int tileHeight = imageHeight / layoutNRow;

      // Compute margin for regular tiles of this layout
      unsigned int maxMargin, maxIter;
      ComputeMaximumStabilityMargin(tileWidth, tileHeight, maxIter, maxMargin);
      tileWidth += 2*maxMargin;
      tileHeight += 2*maxMargin;

      // Tile area, multiplied as floating point so that it can not wrap around
      float surface = static_cast<float>(tileWidth) * static_cast<float>(tileHeight);

      // Memory use efficiency
      float percentMemory = surface / (float) maximumNumberOfNodesInMemory; // ]0, 1]

      // Compactness
      float perimeter = tileWidth + tileHeight;
      float compactness = perimeter / surface * (float) std::max(tileWidth,tileHeight); // [1,+inf]

      // Update minimum criterion
      float criterion = compactness / percentMemory; // ]0, +inf]

      itkDebugMacro(<< std::setprecision (2) << std::fixed
          << "Nb. tiles=" << nbOfTiles
          << " Layout: " << layoutNRow << "x" << layoutNCol
          << " Mem. use=" << percentMemory
          << " Compactness=" << compactness
          << " Criterion=" << criterion
          << " Size (no margin): " << (tileWidth-2*maxMargin)<< "x"<< (tileHeight-2*maxMargin)
          << " Size (with margin): " << tileWidth << "x" << tileHeight
          << " (margin=" << maxMargin << "/nb. iter=" << maxIter << ")" );

      if (criterion < lowestCriterionValue)
        {
        lowestCriterionValue = criterion;
        m_NbTilesX = layoutNCol;
        m_NbTilesY = layoutNRow;
        }
      } // for each divisor of nbOfTiles
    }

  // Compute the tile size
  m_TileWidth = static_cast<unsigned int>(imageWidth/m_NbTilesX);
  m_TileHeight = static_cast<unsigned int>(imageHeight/m_NbTilesY);
  itkDebugMacro(<<"Selected layout: " << m_NbTilesX << "x" << m_NbTilesY
      << " (criterion=" << lowestCriterionValue << ")");

  // Compute the stability margin
  ComputeMaximumStabilityMargin(m_TileWidth, m_TileHeight,m_NumberOfFirstIterations, m_Margin);

  // Memory estimate for a regular tile (margins included), for information only
  long long unsigned int memoryUsed = GetNodeMemory();
  memoryUsed *= static_cast<long long unsigned int>(m_TileHeight + 2*m_Margin);
  memoryUsed *= static_cast<long long unsigned int>(m_TileWidth + 2*m_Margin);
  itkDebugMacro(<< "An amount of " << memoryUsed/(1024.0*1024.0) << " Mbytes of RAM will be used for regular tiles of size "
      << (m_TileWidth + 2*m_Margin) << "x" << (m_TileHeight + 2*m_Margin) );

}

/*
 * Set the memory budget. The value is given in Mbytes and stored in
 * m_Memory after multiplication by 1024 * 1024.
 */
template <class TSegmenter>
void Controller<TSegmenter>::SetInternalMemoryAvailable(long long unsigned int v) // expecting a value in Mbytes.
{
  assert(v > 0);

  // Conversion factor applied to the value given in Mbytes
  const long long unsigned int unitsPerMbyte = 1024ul * 1024ul;
  this->m_Memory = v * unitsPerMbyte;
}

/*
 * Set the image to segment. The pointer is stored in m_InputImage.
 */
template<class TSegmenter>
void Controller<TSegmenter>::SetInputImage(ImageType * inputImage)
{
  this->m_InputImage = inputImage;
}

/*
 * Set the segmenter-specific parameters. They are copied into
 * m_SpecificParameters.
 */
template<class TSegmenter>
void Controller<TSegmenter>::SetSpecificParameters(const SegmentationParameterType& params)
{
  this->m_SpecificParameters = params;
}

/*
 * Return the label image stored in m_LabelImage (assigned by
 * RunSegmentation() when the final merge is performed).
 */
template<class TSegmenter>
typename Controller<TSegmenter>::LabelImageType::Pointer
Controller<TSegmenter>::GetLabeledClusteredOutput()
{
  return this->m_LabelImage;
}
} // end of namespace lsgrm