//
// Distance policy that lets KMeansClustering operate on Vec3f elements.
//
class Vec3fKMeansMetric
{
public:
    //
    // Returns whatever Vec3f::DistSq reports for the pair (L, R); going by its name this is
    // presumably the squared Euclidean distance, so no square root is involved.
    //
    static __forceinline float Dist(const Vec3f &L, const Vec3f &R)
    {
        const float SquaredDistance = Vec3f::DistSq(L, R);
        return SquaredDistance;
    }
};
//
// Distance policy that lets KMeansClustering operate on Vec2f elements.
//
class Vec2fKMeansMetric
{
public:
    //
    // Returns whatever Vec2f::DistSq reports for the pair (L, R); going by its name this is
    // presumably the squared Euclidean distance, so no square root is involved.
    //
    static __forceinline float Dist(const Vec2f &L, const Vec2f &R)
    {
        const float SquaredDistance = Vec2f::DistSq(L, R);
        return SquaredDistance;
    }
};
template<class T>
class KMeansCluster
{
public:
void Init(const T &Start)
{
_Center = Start;
}
void FinishIteration(const T &FallbackElement)
{
if(_Entries.Length() == 0)
{
_Center = FallbackElement;
}
else
{
float Sum = _Entries[0].second;
T NewCenter = _Entries[0].first;
for(UINT EntryIndex = 1; EntryIndex < _Entries.Length(); EntryIndex++)
{
Sum += _Entries[EntryIndex].second;
NewCenter += _Entries[EntryIndex].first;
}
_Center = NewCenter;
if(Sum > 0.0f)
{
_Center *= (1.0f / Sum);
}
_Entries.FreeMemory();
}
}
__forceinline const T& Center() const
{
return _Center;
}
__forceinline void AddEntry(const T &PreWeightedEntry, float Weight)
{
_Entries.PushEnd( pair<T, float>(PreWeightedEntry, Weight) );
}
__forceinline void AddEntry(const T &Entry)
{
_Entries.PushEnd( pair<T, float>(Entry, 1.0f) );
}
private:
T _Center;
Vector< pair<T, float> > _Entries;
};
template<class T, class Metric>
class KMeansClustering
{
public:
void Cluster(const Vector<T> &Elements, UINT ClusterCount, UINT MaxIterations = 0, bool Verbose = true)
{
if(Verbose)
{
Console::WriteLine(String("K-means clustering, ") + String(Elements.Length()) + String (" points, ") + String(ClusterCount) + String(" clusters"));
Console::AdvanceLine();
}
PersistentAssert(Elements.Length() >= ClusterCount, "Too many clusters");
_Clusters.Allocate(ClusterCount);
for(UINT ClusterIndex = 0; ClusterIndex < ClusterCount; ClusterIndex++)
{
_Clusters[ClusterIndex].Init(Elements.RandomElement());
}
UINT PassIndex = 0;
Vector<T> PreviousClusterCenters(_Clusters.Length());
bool Converged = false;
while(!Converged)
{
PassIndex++;
for(UINT ClusterIndex = 0; ClusterIndex < ClusterCount; ClusterIndex++)
{
PreviousClusterCenters[ClusterIndex] = _Clusters[ClusterIndex].Center();
}
Iterate(Elements);
double Delta = 0.0;
for(UINT ClusterIndex = 0; ClusterIndex < ClusterCount; ClusterIndex++)
{
Delta += Metric::Dist(PreviousClusterCenters[ClusterIndex], _Clusters[ClusterIndex].Center());
}
Converged = (Delta == 0.0 || PassIndex == MaxIterations);
if(Verbose)
{
Console::OverwriteLine(String("Pass ") + String(PassIndex) + String(", ") + String("Delta=") + String(Delta));
}
}
}
void Cluster(const Vector<T> &Elements, const Vector<float> &Weights, UINT ClusterCount, UINT MaxIterations = 0, bool Verbose = true)
{
if(Verbose)
{
Console::WriteLine(String("Weighted K-means clustering, ") + String(Elements.Length()) + String (" points, ") + String(ClusterCount) + String(" clusters"));
Console::AdvanceLine();
}
PersistentAssert(Elements.Length() >= ClusterCount, "Too many clusters");
PersistentAssert(Elements.Length() == Weights.Length(), "Incorrect number of weights");
_Clusters.Allocate(ClusterCount);
for(UINT ClusterIndex = 0; ClusterIndex < ClusterCount; ClusterIndex++)
{
_Clusters[ClusterIndex].Init(Elements.RandomElement());
}
Vector<T> WeightedElements = Elements;
for(UINT ElementIndex = 0; ElementIndex < Elements.Length(); ElementIndex++)
{
WeightedElements[ElementIndex] *= Weights[ElementIndex];
}
UINT PassIndex = 0;
Vector<T> PreviousClusterCenters(_Clusters.Length());
bool Converged = false;
while(!Converged)
{
PassIndex++;
for(UINT ClusterIndex = 0; ClusterIndex < ClusterCount; ClusterIndex++)
{
PreviousClusterCenters[ClusterIndex] = _Clusters[ClusterIndex].Center();
}
Iterate(Elements, WeightedElements, Weights);
double Delta = 0.0;
for(UINT ClusterIndex = 0; ClusterIndex < ClusterCount; ClusterIndex++)
{
Delta += Metric::Dist(PreviousClusterCenters[ClusterIndex], _Clusters[ClusterIndex].Center());
}
Converged = (Delta == 0.0 || PassIndex == MaxIterations);
if(Verbose)
{
Console::OverwriteLine(String("Pass ") + String(PassIndex) + String(", ") + String("Delta=") + String(Delta));
}
}
}
__forceinline const T& ClusterCenter(UINT Index)
{
return _Clusters[Index].Center();
}
__forceinline const T& QuantizeToNearestClusterCenter(const T &Element)
{
return _Clusters[QuantizeToNearestClusterCenter(Element)].Center();
}
__forceinline UINT QuantizeToNearestClusterIndex(const T &Element)
{
UINT ClosestClusterIndex = 0;
double ClosestClusterDist = Metric::Dist(Element, _Clusters[0].Center());
for(UINT ClusterIndex = 1; ClusterIndex < _Clusters.Length(); ClusterIndex++)
{
double CurClusterDist = Metric::Dist(Element, _Clusters[ClusterIndex].Center());
if(CurClusterDist < ClosestClusterDist)
{
ClosestClusterIndex = ClusterIndex;
ClosestClusterDist = CurClusterDist;
}
}
return ClosestClusterIndex;
}
__forceinline UINT ClusterCount()
{
return _Clusters.Length();
}
private:
void Iterate(const Vector<T> &Elements)
{
const UINT ElementCount = Elements.Length();
const T* ElementCArray = Elements.CArray();
const UINT ClusterCount = _Clusters.Length();
KMeansCluster<T> *ClusterCArray = _Clusters.CArray();
for(UINT ElementIndex = 0; ElementIndex < ElementCount; ElementIndex++)
{
const T& CurElement = ElementCArray[ElementIndex];
UINT ClosestClusterIndex = 0;
double ClosestClusterDist = Metric::Dist(CurElement, ClusterCArray[0].Center());
for(UINT ClusterIndex = 1; ClusterIndex < ClusterCount; ClusterIndex++)
{
double CurClusterDist = Metric::Dist(CurElement, ClusterCArray[ClusterIndex].Center());
if(CurClusterDist < ClosestClusterDist)
{
ClosestClusterIndex = ClusterIndex;
ClosestClusterDist = CurClusterDist;
}
}
ClusterCArray[ClosestClusterIndex].AddEntry(CurElement);
}
for(UINT ClusterIndex = 0; ClusterIndex < ClusterCount; ClusterIndex++)
{
ClusterCArray[ClusterIndex].FinishIteration(Elements.RandomElement());
}
}
void Iterate(const Vector<T> &Elements, const Vector<T> &WeightedElements, const Vector<float> &Weights)
{
const UINT ElementCount = Elements.Length();
const T* ElementCArray = Elements.CArray();
const float* WeightsCArray = Weights.CArray();
const UINT ClusterCount = _Clusters.Length();
KMeansCluster<T> *ClusterCArray = _Clusters.CArray();
for(UINT ElementIndex = 0; ElementIndex < ElementCount; ElementIndex++)
{
const T& CurElement = ElementCArray[ElementIndex];
float CurWeight = WeightsCArray[ElementIndex];
UINT ClosestClusterIndex = 0;
double ClosestClusterDist = Metric::Dist(CurElement, ClusterCArray[0].Center());
for(UINT ClusterIndex = 1; ClusterIndex < ClusterCount; ClusterIndex++)
{
double CurClusterDist = Metric::Dist(CurElement, ClusterCArray[ClusterIndex].Center());
if(CurClusterDist < ClosestClusterDist)
{
ClosestClusterIndex = ClusterIndex;
ClosestClusterDist = CurClusterDist;
}
}
ClusterCArray[ClosestClusterIndex].AddEntry(WeightedElements[ElementIndex], CurWeight);
}
for(UINT ClusterIndex = 0; ClusterIndex < ClusterCount; ClusterIndex++)
{
ClusterCArray[ClusterIndex].FinishIteration(Elements.RandomElement());
}
}
Vector< KMeansCluster<T> > _Clusters;
};