/*
App.cpp
Written by Matthew Fisher

App.cpp contains main() and is the source of all application-specific code.
*/

#include "Main.h"

class App
{
public:
    void MakeImageGraphs();
    void MakeAllSegmentations(const String &imageFilename, const String &resultFilename);
    void LoadImageGraphs();
    void CompareAllGraphs();

private:
    void QueryGraph(const Graph &g0, ofstream &file);
    double CompareGraphsDynamicProgramming(const Graph &g0, const Graph &g1, UINT pathLength);

    Vector<Graph*> _graphs;
};

void App::MakeAllSegmentations(const String &imageFilename, const String &resultFilename)
{
    Bitmap bmp;
    bmp.LoadPNG(imageFilename);

    UINT clusterCountList[8] = {20, 50, 100, 250, 500, 1000, 2000, 5000};
    for(UINT clusterCountIndex = 0; clusterCountIndex < 8; clusterCountIndex++)
    {
        ImageSegmenter segmenter;
        Grid<UINT> clusterIDs;
        segmenter.Segment(bmp, clusterCountList[clusterCountIndex], 20, clusterIDs);

        Bitmap clusterBmp0, clusterBmp1;
        ImageSegmenter::DrawClusterIDs(clusterIDs, clusterBmp0);
        ImageSegmenter::DrawClusterColors(bmp, clusterIDs, clusterBmp1);
        clusterBmp0.SavePNG(resultFilename + "_ID_" + String(clusterCountList[clusterCountIndex]) + String(".png"));
        clusterBmp1.SavePNG(resultFilename + "_Color_" + String(clusterCountList[clusterCountIndex]) + String(".png"));
    }
}

void App::MakeImageGraphs()
{
    Console::WriteLine("Converting all images into graphs");
    Directory imageDirectory("Database\\Images");

    const UINT clusterCount = 50;
    const UINT segmentationIterations = 10;

    for(UINT imageIndex = 0; imageIndex < imageDirectory.Files().Length(); imageIndex++)
    {
        const String &filename = imageDirectory.Files()[imageIndex];
        if(filename.EndsWith(".png"))
        {
            Console::WriteLine("Converting " + filename);
            Bitmap bmp;
            bmp.LoadPNG(imageDirectory.DirectoryPath() + filename);

            ImageSegmenter segmenter;
            Grid<UINT> clusterIDs;
            segmenter.Segment(bmp, clusterCount, segmentationIterations, clusterIDs);

            Bitmap clusterBmp0, clusterBmp1;
            ImageSegmenter::DrawClusterIDs(clusterIDs, clusterBmp0);
            ImageSegmenter::DrawClusterColors(bmp, clusterIDs, clusterBmp1);
            clusterBmp0.SavePNG("Database\\Segmentations\\ClustersIteration" + filename);
            clusterBmp1.SavePNG("Database\\Segmentations\\ColorsIteration" + filename);

            Graph g;
            g.LoadFromImageClusters(bmp, clusterIDs);
            g.SaveToFile("Database\\Graphs\\" + filename.RemoveSuffix(".png") + String(".txt"));
        }
    }
}

void App::LoadImageGraphs()
{
    Console::WriteLine("Loading all graphs");
    Directory graphDirectory("Database\\Graphs");

    for(UINT imageIndex = 0; imageIndex < graphDirectory.Files().Length(); imageIndex++)
    {
        const String &filename = graphDirectory.Files()[imageIndex];
        if(filename.EndsWith(".txt"))
        {
            Console::WriteLine("Loading " + filename);
            Graph *newGraph = new Graph;
            newGraph->LoadFromFile(graphDirectory.DirectoryPath() + filename);
            newGraph->Name() = filename.RemoveSuffix(".txt");
            _graphs.PushEnd(newGraph);
        }
    }
}

double App::CompareGraphsDynamicProgramming(const Graph &g0, const Graph &g1, UINT pathLength)
{
    NodeKernelColor nodeKernel;
    EdgeKernelLabeled edgeKernel;
    RootedGraphKernelDynamicProgramming rootedGraphKernel;
    GraphKernelAllRootPairs graphKernel;

    rootedGraphKernel.SetParameters(pathLength, pathLength, nodeKernel, edgeKernel);
    graphKernel.SetParameters(rootedGraphKernel);

    return graphKernel.Evaluate(g0, g1);
}

void App::CompareAllGraphs()
{
    ofstream file("Results.txt");
    for(UINT graphIndex = 0; graphIndex < _graphs.Length(); graphIndex++)
    {
        const Graph &curGraph = *_graphs[graphIndex];
        file << curGraph.Name() << endl;
        QueryGraph(curGraph, file);
    }
}

/// One candidate in a query's ranking: a graph paired with its kernel score
/// against the query graph.
struct QueryGraphEntry
{
    double value;       // kernel similarity between the query and *g
    const Graph* g;     // candidate graph (not owned by this entry)
};

/// Deliberately inverted ordering: `a < b` holds when a has the HIGHER score.
/// An ascending sort by this operator therefore puts the best match first.
bool operator < (const QueryGraphEntry &a, const QueryGraphEntry &b)
{
    return b.value < a.value;
}

void App::QueryGraph(const Graph &g0, ofstream &file)
{
    const UINT pathLength = 2;

    Vector<QueryGraphEntry> allEntries;
    for(UINT graphIndex = 0; graphIndex < _graphs.Length(); graphIndex++)
    {
        const Graph &curGraph = *_graphs[graphIndex];
        
        QueryGraphEntry curEntry;
        curEntry.g = &curGraph;
        curEntry.value = CompareGraphsDynamicProgramming(g0, curGraph, pathLength);
        allEntries.PushEnd(curEntry);

        Console::WriteLine(g0.Name() + ":" + curGraph.Name() + "=" + String(curEntry.value));
    }
    allEntries.Sort();
    for(UINT resultIndex = 0; resultIndex < allEntries.Length(); resultIndex++)
    {
        file << allEntries[resultIndex].g->Name() << '\t';
    }
    file << endl;
    for(UINT resultIndex = 0; resultIndex < allEntries.Length(); resultIndex++)
    {
        file << allEntries[resultIndex].value / allEntries[0].value << '\t';
    }
    file << endl << endl;
}

/// Program entry point. Currently it only writes multi-resolution
/// segmentations of three sample images. The graph pipeline (build, load, rank)
/// is disabled by the commented-out calls below.
/// Standard C++ requires main to return int; `void main` is a compiler
/// extension.
int main()
{
    App a;
    a.MakeAllSegmentations("Database\\Images\\VanGoghB.png", "Database\\Segmentations\\VanGogh");
    a.MakeAllSegmentations("Database\\Images\\AnimeA.png", "Database\\Segmentations\\Anime");
    a.MakeAllSegmentations("Database\\Images\\NatureD.png", "Database\\Segmentations\\Nature");
    //a.MakeImageGraphs();
    //a.LoadImageGraphs();
    //a.CompareAllGraphs();
    return 0;
}