template struct PairwiseCouplingConfiguration { PairwiseCouplingConfiguration() { Factory = NULL; } PairwiseCouplingConfiguration(BinaryClassifierFactory *_Factory) { Factory = _Factory; } BinaryClassifierFactory *Factory; }; template class MulticlassClassifierPairwiseCoupling : public MulticlassClassifier { public: MulticlassClassifierPairwiseCoupling() { _Configured = false; } ~MulticlassClassifierPairwiseCoupling() { FreeMemory(); } MulticlassClassifierType Type() const { return MulticlassClassifierTypePairwiseCoupling; } void FreeMemory() { for(UINT RowIndex = 0; RowIndex < _Classifiers.Rows(); RowIndex++) { for(UINT ColIndex = 0; ColIndex < _Classifiers.Cols(); ColIndex++) { BinaryClassifier *CurClassifier = _Classifiers(RowIndex, ColIndex); if(CurClassifier != NULL) { delete CurClassifier; } } } _Classifiers.FreeMemory(); } MulticlassClassifierPairwiseCoupling(const PairwiseCouplingConfiguration &Config) { _Configured = true; _Config = Config; } void Configure(const PairwiseCouplingConfiguration &Config) { _Configured = true; _Config = Config; } void Train(const Dataset &Examples) { FreeMemory(); Console::WriteLine(String("Training multiclass pairwise coupling classifier, ") + String(Examples.Entries().Length()) + String(" examples")); PersistentAssert(_Configured, "Classifier not configured"); const UINT ExampleCount = Examples.Entries().Length(); const UINT ClassCount = Examples.ClassCount(); _ClassCount = ClassCount; _Classifiers.Allocate(Examples.ClassCount(), Examples.ClassCount()); _Classifiers.Clear(NULL); for(UINT RowIndex = 0; RowIndex < ClassCount; RowIndex++) { for(UINT ColIndex = RowIndex + 1; ColIndex < ClassCount; ColIndex++) { _Classifiers(RowIndex, ColIndex) = _Config.Factory->MakeClassifier(); BinaryClassifier &CurClassifier = *(_Classifiers(RowIndex, ColIndex)); ClassifierDataset LocalDataset; LocalDataset.SubclassFromDataset(Examples, RowIndex, ColIndex); Console::WriteLine(String("Training ") + String(RowIndex) + String("(") + String(LocalDataset.CountExamplesOfClass(RowIndex)) + String(" samples) vs. ") + String(ColIndex) + String("(") + String(LocalDataset.CountExamplesOfClass(ColIndex)) + String(" samples)")); CurClassifier.Train(LocalDataset, RowIndex, ColIndex); } } } void Evaluate(const LearnerInput &Input, UINT &Class, Vector &ClassProbabilities) const { Assert(_Configured, "Classifier not configured"); if(ClassProbabilities.Length() != _ClassCount) { ClassProbabilities.Allocate(_ClassCount); } const UINT ClassCount = _Classifiers.Rows(); Vector ClassVotes(ClassCount); Vector ClassProbabilitySum(ClassCount); ClassVotes.Clear(0); ClassProbabilitySum.Clear(0.0); for(UINT RowIndex = 0; RowIndex < ClassCount; RowIndex++) { for(UINT ColIndex = RowIndex + 1; ColIndex < ClassCount; ColIndex++) { BinaryClassifier &CurClassifier = *(_Classifiers(RowIndex, ColIndex)); double LocalProbabilityClass0; UINT LocalClass; CurClassifier.Evaluate(Input, LocalClass, LocalProbabilityClass0); ClassProbabilitySum[RowIndex] += LocalProbabilityClass0; ClassProbabilitySum[ColIndex] += 1.0 - LocalProbabilityClass0; if(LocalClass == RowIndex) { ClassVotes[RowIndex]++; } else { ClassVotes[ColIndex]++; } } } UINT MaxVotesIndex = 0; UINT MaxVotes = ClassVotes[0]; if(ClassProbabilities.Length() != ClassCount) { ClassProbabilities.Allocate(ClassCount); } for(UINT ClassIndex = 0; ClassIndex < ClassCount; ClassIndex++) { if(ClassVotes[ClassIndex] > MaxVotes) { MaxVotes = ClassVotes[ClassIndex]; MaxVotesIndex = ClassIndex; } //ClassProbabilities[ClassIndex] = double(ClassVotes[ClassIndex]) / double(ClassCount - 1); ClassProbabilities[ClassIndex] = ClassProbabilitySum[ClassIndex] / double(ClassCount - 1); } Class = MaxVotesIndex; } void SaveToBinaryStream(OutputDataStream &Stream) const { SignalError("Not implemented"); } void LoadFromBinaryStream(InputDataStream &Stream) { SignalError("Not implemented"); } private: bool _Configured; PairwiseCouplingConfiguration _Config; Grid*> _Classifiers; UINT _ClassCount; }; template class MulticlassClassifierFactoryPairwiseCoupling : public MulticlassClassifierFactory { public: MulticlassClassifierFactoryPairwiseCoupling(const PairwiseCouplingConfiguration &Config) { _Config = Config; } MulticlassClassifier* MakeClassifier() const { MulticlassClassifier *Result = new MulticlassClassifierPairwiseCoupling(_Config); return Result; } private: PairwiseCouplingConfiguration _Config; };