#include "Main.h"

//
// RootedGraphKernelDynamicProgramming
//
// Compares a node n0 of graph g0 with a node n1 of graph g1 by dynamic
// programming over pairs of neighbouring nodes. For every node pair (i, j),
// the table stores one value per path length p = 0.._maxPathLength:
//
//   K_0(i, j) = nodeKernel(i, j)
//   K_p(i, j) = nodeKernel(i, j) *
//               sum over edges e0 incident to i and e1 incident to j of
//                   edgeKernel(e0, e1) * K_{p-1}(other end of e0, other end of e1)
//
// Evaluate() returns K_{_pathLength}(n0, n1). The table is rebuilt lazily on
// the first Evaluate() after the parameters or the graphs change.
//
// Only pointers to the graphs and kernels are stored, so the caller must
// keep those objects alive while this object is in use.
//

void RootedGraphKernelDynamicProgramming::SetParameters(UINT pathLength, UINT maxPathLength, const NodeKernel &nodeKernel, const EdgeKernel &edgeKernel)
{
    _pathLength = pathLength;

    // The table holds lengths 0.._maxPathLength and Evaluate() reads entry
    // _pathLength, so the table must be at least that tall. Otherwise the
    // lookup would run past the end of each per-pair vector.
    _maxPathLength = (maxPathLength < pathLength) ? pathLength : maxPathLength;

    _nodeKernel = &nodeKernel;
    _edgeKernel = &edgeKernel;
    _tableDirty = true;
}

void RootedGraphKernelDynamicProgramming::SetGraphs(const Graph &g0, const Graph &g1)
{
    _g0 = &g0;
    _g1 = &g1;

    // The table is indexed by the nodes of both graphs and built from their
    // kernels, so any table built for the previous graphs is now stale.
    _tableDirty = true;
}

double RootedGraphKernelDynamicProgramming::Evaluate(const Node &n0, const Node &n1) const
{
    if(_tableDirty)
    {
        UpdateTable();
    }
    return GetTableValue(n0.index, n1.index, _pathLength);
}

void RootedGraphKernelDynamicProgramming::UpdateTable() const
{
    if(!_tableDirty)
    {
        return;
    }

    // Hoisted out of the loops below; these are read on every iteration.
    const auto &nodes0 = _g0->nodes();
    const auto &nodes1 = _g1->nodes();

    const UINT n0Count = nodes0.Length();
    const UINT n1Count = nodes1.Length();

    _table.Allocate(n0Count, n1Count);

    //
    // Base case: length-0 entry is the node kernel of the pair. The remaining
    // entries start at -1.0 (presumably an "unset" marker) and are all
    // overwritten below.
    //
    for(UINT n0Index = 0; n0Index < n0Count; n0Index++)
    {
        for(UINT n1Index = 0; n1Index < n1Count; n1Index++)
        {
            Vector &curEntry = _table(n0Index, n1Index);
            curEntry.Allocate(_maxPathLength + 1, -1.0);
            curEntry[0] = _nodeKernel->Evaluate(nodes0[n0Index], nodes1[n1Index]);
        }
    }

    //
    // Recurrence: length p depends only on length p - 1, so filling lengths
    // in increasing order guarantees every value read is already computed.
    //
    for(UINT pathIndex = 1; pathIndex <= _maxPathLength; pathIndex++)
    {
        for(UINT n0Index = 0; n0Index < n0Count; n0Index++)
        {
            const Node &n0 = nodes0[n0Index];
            for(UINT n1Index = 0; n1Index < n1Count; n1Index++)
            {
                const Node &n1 = nodes1[n1Index];
                Vector &curEntry = _table(n0Index, n1Index);

                double sum = 0.0;
                for(auto n0EdgeIterator = n0.edges.Begin(); n0EdgeIterator != n0.edges.End(); ++n0EdgeIterator)
                {
                    const Node &n0Neighbor = *((*n0EdgeIterator)->GetOtherNode(&n0));
                    for(auto n1EdgeIterator = n1.edges.Begin(); n1EdgeIterator != n1.edges.End(); ++n1EdgeIterator)
                    {
                        const Node &n1Neighbor = *((*n1EdgeIterator)->GetOtherNode(&n1));

                        //
                        // TODO: consider caching edge kernel evaluations; they do
                        // not depend on pathIndex and are recomputed each pass.
                        //
                        const double edgeTerm = _edgeKernel->Evaluate(**n0EdgeIterator, **n1EdgeIterator);

                        // Neighbours are looked up by Node::index, which is
                        // assumed to equal their position in nodes().
                        sum += GetTableValue(n0Neighbor.index, n1Neighbor.index, pathIndex - 1) * edgeTerm;
                    }
                }
                curEntry[pathIndex] = curEntry[0] * sum;
            }
        }
    }

    _tableDirty = false;
}