#include "Main.h"
// Configures the kernel: the walk length Evaluate() reports, the longest
// length the DP table is filled to, and the node/edge kernels used inside
// the recurrence. Always invalidates any previously built table.
//
// The kernels are captured by address only; the caller must keep them alive
// for as long as this object evaluates anything.
// NOTE(review): Evaluate() reads slot `pathLength` of a table filled up to
// `maxPathLength`; presumably pathLength <= maxPathLength is required --
// confirm against GetTableValue.
void RootedGraphKernelDynamicProgramming::SetParameters(UINT pathLength, UINT maxPathLength, const NodeKernel &nodeKernel, const EdgeKernel &edgeKernel)
{
    // Kernels used by UpdateTable(): slot 0 uses the node kernel, longer
    // lengths also weight each edge pair with the edge kernel.
    _nodeKernel = &nodeKernel;
    _edgeKernel = &edgeKernel;

    // Lengths: how far the table is built, and which slot Evaluate() returns.
    _maxPathLength = maxPathLength;
    _pathLength = pathLength;

    // Any cached table was computed with the previous settings.
    _tableDirty = true;
}
void RootedGraphKernelDynamicProgramming::SetGraphs(const Graph &g0, const Graph &g1)
{
_g0 = &g0;
_g1 = &g1;
}
// Returns the rooted kernel value for the node pair (n0, n1) at the
// configured walk length `_pathLength`. If the parameters or graphs changed
// since the last build, the DP table is rebuilt first.
//
// Assumes n0 belongs to the first graph and n1 to the second: their `index`
// fields address rows/columns of a table sized from those graphs' node counts.
// NOTE(review): assumes _pathLength <= _maxPathLength (the table is only
// filled up to _maxPathLength) -- confirm GetTableValue's bounds handling.
double RootedGraphKernelDynamicProgramming::Evaluate(const Node &n0, const Node &n1) const
{
    const bool needsRebuild = _tableDirty;
    if(needsRebuild) UpdateTable();

    return GetTableValue(n0.index, n1.index, _pathLength);
}
void RootedGraphKernelDynamicProgramming::UpdateTable() const
{
if(!_tableDirty)
{
return;
}
const UINT n0Count = _g0->nodes().Length();
const UINT n1Count = _g1->nodes().Length();
_table.Allocate(n0Count, n1Count);
for(UINT n0Index = 0; n0Index < n0Count; n0Index++)
{
for(UINT n1Index = 0; n1Index < n1Count; n1Index++)
{
Vector<double> &curEntry = _table(n0Index, n1Index);
curEntry.Allocate(_maxPathLength + 1, -1.0);
curEntry[0] = _nodeKernel->Evaluate(_g0->nodes()[n0Index], _g1->nodes()[n1Index]);
}
}
for(UINT pathIndex = 1; pathIndex <= _maxPathLength; pathIndex++)
{
for(UINT n0Index = 0; n0Index < n0Count; n0Index++)
{
for(UINT n1Index = 0; n1Index < n1Count; n1Index++)
{
Vector<double> &curEntry = _table(n0Index, n1Index);
const Node &n0 = _g0->nodes()[n0Index];
const Node &n1 = _g1->nodes()[n1Index];
double sum = 0.0;
for(auto n0NeighborIterator = n0.edges.Begin(); n0NeighborIterator != n0.edges.End(); n0NeighborIterator++)
{
const Node &n0Neighbor = *((*n0NeighborIterator)->GetOtherNode(&n0));
for(auto n1NeighborIterator = n1.edges.Begin(); n1NeighborIterator != n1.edges.End(); n1NeighborIterator++)
{
const Node &n1Neighbor = *((*n1NeighborIterator)->GetOtherNode(&n1));
double edgeTerm = _edgeKernel->Evaluate(**n0NeighborIterator, **n1NeighborIterator);
sum += GetTableValue(n0Neighbor.index, n1Neighbor.index, pathIndex - 1) * edgeTerm;
}
}
curEntry[pathIndex] = curEntry[0] * sum;
}
}
}
_tableDirty = false;
}