#include "Main.h"

//
// Configures the kernel.
//   pathLength    - path length whose table value Evaluate() returns.
//   maxPathLength - the table is built for every path length 0..maxPathLength.
//                   NOTE(review): pathLength <= maxPathLength is not checked here.
//   nodeKernel    - kernel applied to the two root nodes (stored by address, so it
//                   must stay alive while this object is used).
//   edgeKernel    - kernel applied to pairs of incident edges (stored by address).
// Any change of parameters invalidates the cached table.
//
void RootedGraphKernelDynamicProgramming::SetParameters(UINT pathLength, UINT maxPathLength, const NodeKernel &nodeKernel, const EdgeKernel &edgeKernel)
{
    _tableDirty = true;

    _nodeKernel = &nodeKernel;
    _edgeKernel = &edgeKernel;

    _maxPathLength = maxPathLength;
    _pathLength = pathLength;
}

void RootedGraphKernelDynamicProgramming::SetGraphs(const Graph &g0, const Graph &g1)
{
    _g0 = &g0;
    _g1 = &g1;
}

//
// Returns the kernel value for root nodes n0 (in the first graph) and n1 (in the
// second graph) at the configured path length _pathLength. The table is rebuilt
// first if it is dirty.
//
double RootedGraphKernelDynamicProgramming::Evaluate(const Node &n0, const Node &n1) const
{
    // UpdateTable() returns immediately when the table is already current,
    // so it is safe to call unconditionally.
    UpdateTable();

    const double value = GetTableValue(n0.index, n1.index, _pathLength);
    return value;
}

void RootedGraphKernelDynamicProgramming::UpdateTable() const
{
    if(!_tableDirty)
    {
        return;
    }

    const UINT n0Count = _g0->nodes().Length();
    const UINT n1Count = _g1->nodes().Length();
    _table.Allocate(n0Count, n1Count);
    for(UINT n0Index = 0; n0Index < n0Count; n0Index++)
    {
        for(UINT n1Index = 0; n1Index < n1Count; n1Index++)
        {
            Vector<double> &curEntry = _table(n0Index, n1Index);
            curEntry.Allocate(_maxPathLength + 1, -1.0);

            //
            // Base Case
            //
            curEntry[0] = _nodeKernel->Evaluate(_g0->nodes()[n0Index], _g1->nodes()[n1Index]);
        }
    }

    

    for(UINT pathIndex = 1; pathIndex <= _maxPathLength; pathIndex++)
    {
        for(UINT n0Index = 0; n0Index < n0Count; n0Index++)
        {
            for(UINT n1Index = 0; n1Index < n1Count; n1Index++)
            {
                Vector<double> &curEntry = _table(n0Index, n1Index);
                
                const Node &n0 = _g0->nodes()[n0Index];
                const Node &n1 = _g1->nodes()[n1Index];

                double sum = 0.0;
                for(auto n0NeighborIterator = n0.edges.Begin(); n0NeighborIterator != n0.edges.End(); n0NeighborIterator++)
                {
                    const Node &n0Neighbor = *((*n0NeighborIterator)->GetOtherNode(&n0));
                    for(auto n1NeighborIterator = n1.edges.Begin(); n1NeighborIterator != n1.edges.End(); n1NeighborIterator++)
                    {
                        const Node &n1Neighbor = *((*n1NeighborIterator)->GetOtherNode(&n1));

                        //
                        // TODO: consider caching edge kernel evaluations
                        //
                        double edgeTerm = _edgeKernel->Evaluate(**n0NeighborIterator, **n1NeighborIterator);

                        sum += GetTableValue(n0Neighbor.index, n1Neighbor.index, pathIndex - 1) * edgeTerm;
                    }
                }
                
                curEntry[pathIndex] = curEntry[0] * sum;
            }
        }
    }

    _tableDirty = false;
}