#include "Main.h"

// Default constructor: leaves the node and edge arrays to their own default
// constructors; no additional setup is performed.
Graph::Graph()
{
}

// Copy constructor: deep-copies g's node and edge arrays, then re-targets the
// copied edges so they reference this graph's nodes rather than g's.
// Relies on g's nodes carrying valid `index` fields (set by Finalize).
// NOTE(review): no copy-assignment operator is visible in this file; if the
// implicit one exists, assignment would leave edges pointing into the source
// graph — confirm against the header.
Graph::Graph(const Graph &g)
{
    // After these copies every edge's n0/n1 still points into g._nodes.
    _nodes = g._nodes;
    _edges = g._edges;

    const UINT edgeTotal = _edges.Length();
    for(UINT i = 0; i < edgeTotal; i++)
    {
        Edge &copied = _edges[i];

        // Read the source-node indices from g before overwriting the pointers.
        const UINT firstIndex = copied.n0->index;
        const UINT secondIndex = copied.n1->index;

        copied.n0 = &_nodes[firstIndex];
        copied.n1 = &_nodes[secondIndex];
    }

    // Rebuild node indices and per-node adjacency lists against the new arrays
    // (the copied nodes' edge lists still reference g's edges until then).
    Finalize();
}

// Destructor: releases the edge and node storage via FreeMemory.
Graph::~Graph()
{
    this->FreeMemory();
}

void Graph::FreeMemory()
{
    _edges.FreeMemory();
    _nodes.FreeMemory();
}

// Returns the first edge (in _edges order) connecting the nodes with indices
// n0Index and n1Index, in either orientation, or NULL if no such edge exists.
// Linear in the number of edges.
Edge* Graph::FindEdge(UINT n0Index, UINT n1Index)
{
    const UINT edgeTotal = _edges.Length();
    for(UINT i = 0; i < edgeTotal; i++)
    {
        Edge &candidate = _edges[i];
        const UINT a = candidate.n0->index;
        const UINT b = candidate.n1->index;

        const bool sameOrientation = (a == n0Index) && (b == n1Index);
        const bool swappedOrientation = (a == n1Index) && (b == n0Index);
        if(sameOrientation || swappedOrientation)
        {
            return &candidate;
        }
    }
    return NULL;
}

void Graph::LoadFromImageClusters(const Bitmap &bmp, const Grid<UINT> &clusterIDs)
{
    FreeMemory();
    
    const UINT clusterCount = clusterIDs.MaxValue() + 1;
    _nodes.Allocate(clusterCount);

    const UINT width = clusterIDs.Cols();
    const UINT height = clusterIDs.Rows();

    Vector<Vec3f> colorSum(clusterCount, Vec3f::Origin);
    Vector<UINT> colorCount(clusterCount, 0);
    for(UINT y = 0; y < height; y++)
    {
        for(UINT x = 0; x < width; x++)
        {
            const UINT clusterIndex = clusterIDs(y, x);
            colorSum[clusterIndex] += Vec3f(bmp[y][x]);
            colorCount[clusterIndex]++;
        }
    }
    
    for(UINT clusterIndex = 0; clusterIndex < clusterCount; clusterIndex++)
    {
        _nodes[clusterIndex].color = RGBColor(colorSum[clusterIndex] / float(colorCount[clusterIndex]));
    }

    set<UINT64> insertedEdges;
    for(UINT y = 0; y < height - 1; y++)
    {
        for(UINT x = 0; x < width - 1; x++)
        {
            UINT baseClusterIndex = clusterIDs(y, x);
            for(UINT neighborIndex = 0; neighborIndex < 2; neighborIndex++)
            {
                UINT otherClusterIndex;
                if(neighborIndex == 0) otherClusterIndex = clusterIDs(y, x + 1);
                else otherClusterIndex = clusterIDs(y + 1, x);
                if(otherClusterIndex != baseClusterIndex)
                {
                    UINT minClusterIndex = Math::Min(baseClusterIndex, otherClusterIndex);
                    UINT maxClusterIndex = Math::Max(baseClusterIndex, otherClusterIndex);
                    UINT64 indexHash = (UINT64(minClusterIndex) << 32) + UINT64(maxClusterIndex);
                    if(insertedEdges.count(indexHash) == 0)
                    {
                        insertedEdges.insert(indexHash);
                        Edge newEdge;
                        newEdge.n0 = &_nodes[minClusterIndex];
                        newEdge.n1 = &_nodes[maxClusterIndex];
                        //newEdge.label = 0;
                        _edges.PushEnd(newEdge);
                    }
                }
            }
        }
    }

    Finalize();
}

// Enumerates every walk that starts at `start` and traverses exactly `length`
// edges (walks may revisit nodes and edges; nothing prevents backtracking).
//
// Each resulting walk is heap-allocated with `new`; the caller owns the
// pointers placed in `walks` and must delete them (e.g. via DeleteMemory, as
// is done here for the intermediate results).
//
// NOTE(review): for length == 0 the vector is reset with walks.Allocate(1),
// while for length > 0 results are appended with PushEnd. Callers should pass
// an empty vector — confirm whether the asymmetry is intended.
void Graph::EnumerateAllWalks(UINT length, const Node &start, Vector<Walk*> &walks) const
{
    // Base case: the zero-length walk consisting of just the start node.
    if(length == 0)
    {
        walks.Allocate(1);
        walks[0] = new Walk(&start);
        return;
    }

    const UINT remainingLength = length - 1;
    const UINT degree = start.edges.Length();
    for(UINT i = 0; i < degree; i++)
    {
        const Node *neighbor = start.edges[i]->GetOtherNode(&start);

        // Collect every walk of the remaining length that leaves the neighbour...
        Vector<Walk*> tails;
        EnumerateAllWalks(remainingLength, *neighbor, tails);

        // ...and extend each one with the start node (presumably prepended by
        // this Walk constructor — confirm in Walk's definition).
        for(UINT t = 0; t < tails.Length(); t++)
        {
            walks.PushEnd(new Walk(tails[t], &start));
        }

        // The extended copies are independent of the tails, which are freed here.
        tails.DeleteMemory();
    }
}

void Graph::LoadFromFile(const String &filename)
{
    FreeMemory();
    Vector<Graph*> graphs(1);
    graphs[0] = this;
    LoadGraphsFromFile(filename, graphs);
    PersistentAssert(graphs.Length() == 1, "More than one graph found in file");
}

void Graph::LoadGraphsFromFile(const String &filename, Vector<Graph*> &graphs)
{
    Vector<String> lines, words;
    Utility::GetFileLines(filename, lines);
    
    const String &graphCountLine = lines[0];
    graphCountLine.Partition(' ', words);

    Assert(words.Length() == 2 && words[0] == "GraphCount", String("Invalid line encountered in ") + filename);
    UINT graphCount = words[1].ConvertToUnsignedInteger();

    if(graphs.Length() != graphCount)
    {
        graphs.Allocate(graphCount, NULL);

        for(UINT graphIndex = 0; graphIndex < graphCount; graphIndex++)
        {
            graphs[graphIndex] = new Graph;
        }
    }

    Graph *activeGraph = NULL;

    for(UINT lineIndex = 1; lineIndex < lines.Length(); lineIndex++)
    {
        const String &curLine = lines[lineIndex];
        if(curLine.Length() > 2)
        {
            curLine.Partition(' ', words);
            Assert(words.Length() >= 4 && words[0].Length() == 1, "Invalid line");
            
            char identifier = words[0][0];
            UINT objectIndex = words[1].ConvertToUnsignedInteger();
            const String &prop = words[2];
            const String &propValue = words[3];

            switch(identifier)
            {
            case 'g':
                {
                    activeGraph = graphs[objectIndex];
                    if(prop == "NodeCount")
                    {
                        activeGraph->_nodes.Allocate(propValue.ConvertToUnsignedInteger());
                    }
                    else if(prop == "EdgeCount")
                    {
                        activeGraph->_edges.Allocate(propValue.ConvertToUnsignedInteger());
                    }
                    else
                    {
                        SignalError("Invalid graph property");
                    }
                }
                break;
            case 'n':
                {
                    Assert(activeGraph != NULL, "Node data encountered before graph definition");
                    Node &activeNode = activeGraph->_nodes[objectIndex];
                    if(prop == "Color")
                    {
                        Vector<String> values;
                        propValue.Partition(',', values);
                        activeNode.color = RGBColor(values[0].ConvertToUnsignedInteger(), values[1].ConvertToUnsignedInteger(), values[2].ConvertToUnsignedInteger());
                    }
                    else
                    {
                        SignalError("Invalid node property");
                    }
                }
                break;
            case 'e':
                {
                    Assert(activeGraph != NULL, "Edge data encountered before graph definition");
                    Edge &activeEdge = activeGraph->_edges[objectIndex];
                    if(prop == "Label")
                    {
                        //activeEdge.label = propValue.ConvertToUnsignedInteger();
                    }
                    else if(prop == "Nodes")
                    {
                        Vector<String> values;
                        propValue.Partition(',', values);
                        Assert(values.Length() == 2, "Invalid node indices");
                        activeEdge.n0 = &activeGraph->_nodes[values[0].ConvertToUnsignedInteger()];
                        activeEdge.n1 = &activeGraph->_nodes[values[1].ConvertToUnsignedInteger()];
                    }
                    else
                    {
                        SignalError("Invalid edge property");
                    }
                }
                break;
            default:
                SignalError("Invalid identifier");
            }
        }
    }
    
    for(UINT graphIndex = 0; graphIndex < graphCount; graphIndex++)
    {
        graphs[graphIndex]->Finalize();
    }
}

void Graph::SaveToFile(const String &filename) const
{
    ofstream file(filename.CString());

    const UINT nodeCount = _nodes.Length();
    const UINT edgeCount = _edges.Length();
    file << "GraphCount 1" << endl;
    file << "g 0 NodeCount " << nodeCount << endl;
    file << "g 0 EdgeCount " << edgeCount << endl;
    for(UINT nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++)
    {
        const Node &curNode = _nodes[nodeIndex];
        String nodeHeader = String("n ") + String(nodeIndex) + String(" ");
        file << nodeHeader << "Color " << UINT(curNode.color.r) << "," << UINT(curNode.color.g) << "," << UINT(curNode.color.b) << endl;
    }
    for(UINT edgeIndex = 0; edgeIndex < edgeCount; edgeIndex++)
    {
        const Edge &curEdge = _edges[edgeIndex];
        String edgeHeader = String("e ") + String(edgeIndex) + String(" ");
        file << edgeHeader << "Nodes " << curEdge.n0->index << ',' << curEdge.n1->index << endl;
        //file << edgeHeader << "Label " << curEdge.label << endl;
    }
}

void Graph::Finalize()
{
    for(UINT nodeIndex = 0; nodeIndex < _nodes.Length(); nodeIndex++)
    {
        Node &curNode = _nodes[nodeIndex];
        curNode.index = nodeIndex;
        curNode.edges.FreeMemory();
    }
    for(UINT edgeIndex = 0; edgeIndex < _edges.Length(); edgeIndex++)
    {
        Edge &curEdge = _edges[edgeIndex];
        curEdge.n0->edges.PushEnd(&curEdge);
        curEdge.n1->edges.PushEnd(&curEdge);
    }
}