#include "Main.h"

// Configures the kernel: records the walk-enumeration parameter handed to
// Graph::EnumerateAllWalks and the kernel used to score each pair of walks.
// Only the address of walkKernel is retained (Evaluate dereferences it later),
// so the caller must keep that object alive while this kernel is in use.
void RootedGraphKernelEnumerateWalks::SetParameters(UINT pathLength, const WalkKernel &walkKernel)
{
    _walkKernel = &walkKernel;
    _pathLength = pathLength;
}

// Selects the two graphs that Evaluate compares: n0 is looked up in g0 and
// n1 in g1. Only their addresses are stored, so both graphs must outlive
// every subsequent call to Evaluate.
void RootedGraphKernelEnumerateWalks::SetGraphs(const Graph &g0, const Graph &g1)
{
    _g1 = &g1;
    _g0 = &g0;
}

double RootedGraphKernelEnumerateWalks::Evaluate(const Node &n0, const Node &n1) const
{
    Vector<Walk*> g0Walks, g1Walks;
    _g0->EnumerateAllWalks(_pathLength, n0, g0Walks);
    _g1->EnumerateAllWalks(_pathLength, n1, g1Walks);

    double result = 0.0;
    for (UINT g0WalkIndex = 0; g0WalkIndex < g0Walks.Length(); g0WalkIndex++)
    {
        const Walk &g0Walk = *(g0Walks[g0WalkIndex]);
        
        for (UINT g1WalkIndex = 0; g1WalkIndex < g1Walks.Length(); g1WalkIndex++)
        {
            const Walk &g1Walk = *(g1Walks[g1WalkIndex]);

            result += _walkKernel->Evaluate(g0Walk, g1Walk);
        }
    }
    g0Walks.DeleteMemory();
    g1Walks.DeleteMemory();

    return result;
}