#include "Main.h"

void ImageCluster::Reset(const Bitmap &bmp, const Vec2i &seed)
{
    // Restart this cluster around a single seed pixel: remember the seed,
    // throw away every previously collected coordinate, and make the seed
    // the cluster's only member.
    //
    // NOTE(review): ImageSegmenter::GrowClusters later hands seed() back to
    // AssignPixel, which calls AddCoord again. After a growth pass the seed
    // therefore seems to be stored twice in _coords (size off by one,
    // centroids weighted toward the seed). Confirm whether that is intended.
    _seed = seed;
    _coords.FreeMemory();
    AddCoord(bmp, _seed);
}

void ImageCluster::AddCoord(const Bitmap &bmp, const Vec2i &coord)
{
    // Record `coord` as a member of this cluster. Membership is tracked by
    // position only; the bitmap argument is currently not consulted.
    (void)bmp;
    _coords.PushEnd(coord);
}

Vec2i ImageCluster::MassCentroid(const Vec2i &dimensions)
{
    // Returns the member pixel closest (squared Euclidean distance) to the
    // cluster's mean position. `dimensions` is currently unused.
    //
    // An empty cluster has no centroid; fall back to the seed instead of
    // dividing by zero and reading _coords[0] out of range.
    const UINT coordCount = _coords.Length();
    if(coordCount == 0)
    {
        return _seed;
    }

    // Mean position of all member pixels.
    Vec2f center = Vec2f::Origin;
    for(UINT coordIndex = 0; coordIndex < coordCount; coordIndex++)
    {
        center += Vec2f(_coords[coordIndex]);
    }
    center /= float(coordCount);

    // The mean is generally not itself a member pixel (and need not lie inside
    // a non-convex cluster), so snap to the nearest member. Ties keep the
    // earliest member; index 0 is the initial candidate, so start at 1.
    Vec2i bestCoord = _coords[0];
    float bestDistSq = Vec2f::DistSq(Vec2f(bestCoord), center);
    for(UINT coordIndex = 1; coordIndex < coordCount; coordIndex++)
    {
        const float curDistSq = Vec2f::DistSq(Vec2f(_coords[coordIndex]), center);
        if(curDistSq < bestDistSq)
        {
            bestDistSq = curDistSq;
            bestCoord = _coords[coordIndex];
        }
    }
    return bestCoord;
}

Vec2i ImageCluster::ColorCentroid(const Bitmap &bmp, const Vec2i &dimensions)
{
    // Returns the member pixel whose colour is closest (squared RGB distance)
    // to the cluster's mean colour. `dimensions` is currently unused.
    //
    // An empty cluster has nothing to average; fall back to the seed instead
    // of dividing by zero and reading _coords[0] out of range.
    const UINT coordCount = _coords.Length();
    if(coordCount == 0)
    {
        return _seed;
    }

    // Mean colour over all member pixels (bitmap is indexed as bmp[y][x]).
    Vec3f averageColor = Vec3f::Origin;
    for(UINT coordIndex = 0; coordIndex < coordCount; coordIndex++)
    {
        const Vec2i &curCoord = _coords[coordIndex];
        averageColor += Vec3f(bmp[curCoord.y][curCoord.x]);
    }
    averageColor /= float(coordCount);

    // Snap to the member with the most representative colour. Ties keep the
    // earliest member; index 0 is the initial candidate, so start at 1.
    Vec2i bestCoord = _coords[0];
    float bestDistSq = Vec3f::DistSq(Vec3f(bmp[bestCoord.y][bestCoord.x]), averageColor);
    for(UINT coordIndex = 1; coordIndex < coordCount; coordIndex++)
    {
        const Vec2i &curCoord = _coords[coordIndex];
        const float curDistSq = Vec3f::DistSq(Vec3f(bmp[curCoord.y][curCoord.x]), averageColor);
        if(curDistSq < bestDistSq)
        {
            bestDistSq = curDistSq;
            bestCoord = curCoord;
        }
    }
    return bestCoord;
}

double ImageCluster::AssignmentError(const Bitmap &bmp, const Vec2i &coord)
{
    Vec3f pixelColor = Vec3f(bmp[coord.y][coord.x]);
    Vec3f clusterColor = Vec3f(bmp[_seed.y][_seed.x]);
    const double colorError = Vec3f::Dist(pixelColor, clusterColor);

    //const float scale = 1.0f / Math::Max(bmp.Width(), bmp.Height());
    //const double seedDistError = Vec2f::Dist( Vec2f(_seed) * scale, Vec2f(coord) * scale );

    const double sizeError = sqrt(double(_coords.Length()));

    return colorError + sizeError * 0.025;

    //return colorError * colorErrorScale + sizeError * 0.01;
    //return seedDistError;
}

void ImageSegmenter::Segment(const Bitmap &bmp, UINT clusterCount, UINT iterationCount, Grid<UINT> &clusterIDs)
{
    ComponentTimer timer( "Segmenting bitmap, " + String(bmp.Width()) + "x" + String(bmp.Height()) );
    
    _dimensions = Vec2i(bmp.Width(), bmp.Height());
    _assignments.Allocate(_dimensions.y, _dimensions.x);
    _clusters.Allocate(clusterCount);

    _clusterSizeCutoff = bmp.Width() * bmp.Height() / clusterCount / 5;
    
    InitializeClusters(bmp);

    for(UINT iterationIndex = 0; iterationIndex < iterationCount; iterationIndex++)
    {
        //ComponentTimer timer( "Iteration " + String(iterationIndex) );
        GrowClusters(bmp);

        const bool dumpIntermediateResults = false;
        if(dumpIntermediateResults && (iterationIndex % 3 == 0 || iterationIndex == iterationCount - 1))
        {
            Bitmap clusterBmp0, clusterBmp1;
            //DrawClusterIDs(_assignments, clusterBmp0);
            DrawClusterColors(bmp, _assignments, clusterBmp1);
            //clusterBmp0.SavePNG("ClustersIteration" + String(iterationIndex) + ".png");
            clusterBmp1.SavePNG("ColorsIteration" + String(iterationIndex) + ".png");
            String clusterString = "{";
            for(UINT clusterIndex = 0; clusterIndex < _clusters.Length(); clusterIndex++)
            {
                clusterString += String(_clusters[clusterIndex].coords().Length()) + ",";
            }
            clusterString += "}";
            Console::WriteLine(clusterString);
        }

        RecenterClusters(bmp);
    }

    clusterIDs = _assignments;
}

void ImageSegmenter::InitializeClusters(const Bitmap &bmp)
{
    Vector<Vec2i> seeds;
    for(UINT clusterIndex = 0; clusterIndex < _clusters.Length(); clusterIndex++)
    {
        ImageCluster &curCluster = _clusters[clusterIndex];
        Vec2i randomSeed(rand() % _dimensions.x, rand() % _dimensions.y);
        while(seeds.Contains(randomSeed))
        {
            randomSeed = Vec2i(rand() % _dimensions.x, rand() % _dimensions.y);
        }
        curCluster.Reset(bmp, randomSeed);
        seeds.PushEnd(randomSeed);
    }
}

void ImageSegmenter::AssignPixel(const Bitmap &bmp, const Vec2i &coord, UINT clusterIndex)
{
    if(_assignments(coord.y, coord.x) != 0xFFFFFFFF)
    {
        return;
    }
    _assignments(coord.y, coord.x) = clusterIndex;
    _clusters[clusterIndex].AddCoord(bmp, coord);

    const UINT neighborCount = 4;
    const UINT XOffsets[neighborCount] = {-1, 1, 0, 0};
    const UINT YOffsets[neighborCount] = {0, 0, -1, 1};
    for(UINT neighborIndex = 0; neighborIndex < neighborCount; neighborIndex++)
    {
        Vec2i finalCoord(coord.x + XOffsets[neighborIndex], coord.y + YOffsets[neighborIndex]);
        if(_assignments.ValidCoordinates(finalCoord.y, finalCoord.x) && _assignments(finalCoord.y, finalCoord.x) == 0xFFFFFFFF)
        {
            QueueEntry newEntry;
            newEntry.clusterIndex = clusterIndex;
            newEntry.coord = finalCoord;
            newEntry.priority = 1.0 - _clusters[clusterIndex].AssignmentError(bmp, finalCoord);
            _queue.push(newEntry);
        }
    }
}

void ImageSegmenter::GrowClusters(const Bitmap &bmp)
{
    _assignments.Clear(0xFFFFFFFF);

    //
    // Insert all seeds
    //
    for(UINT clusterIndex = 0; clusterIndex < _clusters.Length(); clusterIndex++)
    {
        ImageCluster &curCluster = _clusters[clusterIndex];
        AssignPixel(bmp, curCluster.seed(), clusterIndex);
    }

    while(!_queue.empty())
    {
        QueueEntry curEntry = _queue.top();
        _queue.pop();
        AssignPixel(bmp, curEntry.coord, curEntry.clusterIndex);
    }
}

void ImageSegmenter::RecenterClusters(const Bitmap &bmp)
{
    UINT teleportCount = 0;
    Vector<Vec2i> seeds;
    for(UINT clusterIndex = 0; clusterIndex < _clusters.Length(); clusterIndex++)
    {
        ImageCluster &curCluster = _clusters[clusterIndex];
        Vec2i newSeed;
        if(curCluster.coords().Length() < _clusterSizeCutoff)
        {
            newSeed = Vec2i(rand() % _dimensions.x, rand() % _dimensions.y);
            while(seeds.Contains(newSeed))
            {
                newSeed = Vec2i(rand() % _dimensions.x, rand() % _dimensions.y);
            }
            teleportCount++;
        }
        else
        {
            newSeed = curCluster.ColorCentroid(bmp, _dimensions);
        }
        curCluster.Reset(bmp, newSeed);
        seeds.PushEnd(newSeed);
    }
    //Console::WriteLine(String("Teleport count: ") + String(teleportCount));
}

void ImageSegmenter::DrawClusterIDs(const Grid<UINT> &clusterIDs, Bitmap &bmp)
{
    const UINT clusterCount = clusterIDs.MaxValue() + 1;
    Vector<RGBColor> colors(clusterCount);
    ColorGenerator::Generate(colors);
    bmp.Allocate(clusterIDs.Cols(), clusterIDs.Rows());
    for(UINT y = 0; y < bmp.Height(); y++)
    {
        for(UINT x = 0; x < bmp.Width(); x++)
        {
            bmp[y][x] = colors[clusterIDs(y, x)];
        }
    }
}

void ImageSegmenter::DrawClusterColors(const Bitmap &inputBmp, const Grid<UINT> &clusterIDs, Bitmap &outputBmp)
{
    const UINT clusterCount = clusterIDs.MaxValue() + 1;
    Vector<Vec3f> colorSum(clusterCount, Vec3f::Origin);
    Vector<UINT> colorCount(clusterCount, 0);
    for(UINT y = 0; y < inputBmp.Height(); y++)
    {
        for(UINT x = 0; x < inputBmp.Width(); x++)
        {
            const UINT clusterIndex = clusterIDs(y, x);
            colorSum[clusterIndex] += Vec3f(inputBmp[y][x]);
            colorCount[clusterIndex]++;
        }
    }
    
    Vector<RGBColor> colors(clusterCount);
    for(UINT clusterIndex = 0; clusterIndex < clusterCount; clusterIndex++)
    {
        colors[clusterIndex] = RGBColor(colorSum[clusterIndex] / float(colorCount[clusterIndex]));
    }

    outputBmp.Allocate(clusterIDs.Cols(), clusterIDs.Rows());
    for(UINT y = 0; y < outputBmp.Height(); y++)
    {
        for(UINT x = 0; x < outputBmp.Width(); x++)
        {
            outputBmp[y][x] = colors[clusterIDs(y, x)];
        }
    }
}