#include "Main.h"

void WalkKernelSimple::SetParameters(const NodeKernel &nodeKernel, const EdgeKernel &edgeKernel)
{
    _nKernel = &nodeKernel;
    _eKernel = &edgeKernel;
}

double WalkKernelSimple::Evaluate(const Walk &w0, const Walk &w1) const
{
    UINT length = w0.nodes().Length();
    if(length != w1.nodes().Length())
        return 0.0;
    
    double result = 1.0;
    for (UINT visitIndex = 0; visitIndex < length; visitIndex++)
    {
        const Node &v0 = *w0.nodes()[visitIndex];
        const Node &v1 = *w1.nodes()[visitIndex];
        result *= _nKernel->Evaluate(v0, v1);
        if (visitIndex != (length - 1))
        {
            const Node &v0Next = *w0.nodes()[visitIndex + 1];
            const Node &v1Next = *w1.nodes()[visitIndex + 1];
            result *= _eKernel->Evaluate(*(v0.FindEdgeTo(&v0Next)), *(v1.FindEdgeTo(&v1Next)));
        }
    }

    return result;

}