#include "VP_Tree.h"

#include <cstdlib>
#include <algorithm>
#include <stack>
#include <utility>
#include <vector>

#include "../math/utility.h"

namespace meow{
  ///////////////////// **# Node #** ///////////////////////

  // A tree node refers to its vantage point by index into _vectors.
  // _threshold is overwritten by build() for nodes that have children; it is
  // initialised here so leaf nodes never carry an indeterminate value (dup()
  // copies it unconditionally).
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::Node::Node(size_t __index):
  _index(__index), _nearChild(NULL), _farChild(NULL){
    _threshold = Scalar(0);
  }

  ///////////////////// **# Answer #** /////////////////////

  // One candidate result: index into _vectors plus its squared distance.
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::Answer::Answer(size_t __index,
                                                 Scalar const& __dist2):
  _index(__index), _dist2(__dist2){
  }

  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::Answer::Answer(Answer const& __answer2):
  _index(__answer2._index), _dist2(__answer2._dist2){
  }

  // Orders answers by squared distance; when __cmpValue is true, ties are
  // broken by comparing the referenced vectors with operator<.
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::AnswerCompare::AnswerCompare(
      Vectors const* __vectors, bool __cmpValue):
  _vectors(__vectors), _cmpValue(__cmpValue){
  }

  template<class Vector, class Scalar>
  inline bool VP_Tree<Vector, Scalar>::AnswerCompare::operator()(
      Answer const& __a, Answer const& __b) const{
    if(__a._dist2 < __b._dist2) return true;
    if(__b._dist2 < __a._dist2) return false;
    return (_cmpValue && ((*_vectors)[__a._index] < (*_vectors)[__b._index]));
  }

  //////// **# distance2, distanceCompare, split #** ///////

  // Squared Euclidean distance over the first _dimension components.
  template<class Vector, class Scalar>
  inline Scalar VP_Tree<Vector, Scalar>::distance2(Vector const& __v1,
                                                   Vector const& __v2) const{
    Scalar ret(0);
    for(size_t i = 0; i < _dimension; i++){
      ret += squ(__v1[i] - __v2[i]);
    }
    return ret;
  }

  // Compares sqrt(__a2) + s * sqrt(|__b2|) against sqrt(__c2) without taking
  // square roots, where s = +1 if __b2 >= 0 and s = -1 otherwise.
  // Returns -1 (less), 0 (equal) or 1 (greater).
  // Assumes __a2, __c2 and |__b2| are non-negative squared distances.
  template<class Vector, class Scalar>
  inline int VP_Tree<Vector, Scalar>::distanceCompare(Scalar const& __a2,
                                                      Scalar const& __b2,
                                                      Scalar const& __c2) const{
    if(__b2 < 0){
      // sign(sqrt(a) - sqrt(b) - sqrt(c)) == -sign(sqrt(c) + sqrt(b) - sqrt(a))
      return -distanceCompare(__c2, -__b2, __a2);
    }
    // (sqrt(a) + sqrt(b))^2 = a + b + 2 sqrt(ab); compare against c.
    Scalar cab(__c2 - __a2 - __b2);
    if(cab < Scalar(0)) return 1;
    Scalar ab2(Scalar(4) * __a2 * __b2), cab2(squ(cab));
    if     (ab2 < cab2) return -1;
    else if(cab2 < ab2) return  1;
    else                return  0;
  }

  // Quickselect on _vectors[__first..__last] (inclusive) keyed by squared
  // distance to __center.  Afterwards the element at __first + __order has
  // rank __order; elements before it are <= and elements after it are >= in
  // squared distance.  Returns that element's squared distance.
  template<class Vector, class Scalar>
  inline Scalar VP_Tree<Vector, Scalar>::split(ssize_t __first, ssize_t __last,
                                               size_t __order,
                                               Vector const& __center){
    ssize_t const first0 = __first;
    ssize_t order = static_cast<ssize_t>(__order);
    std::vector<Scalar> dist2(__last - __first + 1);
    for(ssize_t i = __first; i <= __last; i++){
      dist2[i - first0] = distance2(_vectors[i], __center);
    }
    while(__first < __last){
      ssize_t threshold_index = __first + rand() % (__last - __first + 1);
      Scalar threshold(dist2[threshold_index - first0]);
      // Partition: [__first, large_first) <= threshold < [large_first, __last]
      ssize_t large_first = __last + 1;
      for(ssize_t i = __first; __first <= large_first - 1; large_first--){
        if(threshold < dist2[large_first - 1 - first0]) continue;
        while(i < large_first - 1 && !(threshold < dist2[i - first0])) i++;
        if(i < large_first - 1){
          std::swap(dist2   [large_first - 1 - first0], dist2   [i - first0]);
          std::swap(_vectors[large_first - 1         ], _vectors[i         ]);
          i++;
        }else{
          break;
        }
      }
      if(large_first == __last + 1){
        // Every element is <= threshold, so the pivot is a maximum: park it
        // at the end and shrink the range from the right.
        std::swap(dist2   [threshold_index - first0], dist2   [__last - first0]);
        std::swap(_vectors[threshold_index         ], _vectors[__last         ]);
        if(order == __last - __first){
          __first = __last;
          break;
        }
        __last--;
      }else{
        if(order < large_first - __first){
          __last = large_first - 1;
        }else{
          order -= large_first - __first;
          __first = large_first;
        }
      }
    }
    return dist2[__first - first0];
  }

  ////////////////////// **# build() #** ///////////////////

  // Builds a subtree over _vectors[__first..__last] (inclusive); returns NULL
  // for an empty range.  A random element is moved to __first as the vantage
  // point; the rest is split around the median squared distance so that
  // [__first+1, mid-1] (near) are <= _threshold and [mid, __last] (far) >=.
  template<class Vector, class Scalar>
  inline typename VP_Tree<Vector, Scalar>::Node*
  VP_Tree<Vector, Scalar>::build(ssize_t __first, ssize_t __last){
    if(__first > __last) return NULL;
    Node* ret = new Node(__first);
    if(__first < __last){
      std::swap(_vectors[__first],
                _vectors[__first + rand() % (__last - __first + 1)]);
      ssize_t mid = (__first + 1 + __last + 1) / 2;
      ret->_threshold = split(__first + 1, __last, mid - (__first + 1),
                              _vectors[__first]);
      ret->_nearChild = build(__first + 1, mid - 1);
      ret->_farChild  = build(mid        , __last );
    }
    return ret;
  }

  ////////////////////// **# query() #** ///////////////////

  // Recursive k-nearest search.  __out keeps at most __k answers with the
  // worst one (per __cmp) at top().  Subtrees are pruned with the triangle
  // inequality: the near ball (radius sqrt(_threshold)) is visited only if
  // dist - radius <= sqrt(_threshold), the far shell only if
  // dist + radius >= sqrt(_threshold), where radius is the current k-th best.
  template<class Vector, class Scalar>
  inline void VP_Tree<Vector, Scalar>::query(Vector const& __vector,
                                             size_t __k,
                                             AnswerCompare const& __cmp,
                                             Node const* __node,
                                             Answers* __out) const{
    // __k == 0 would otherwise call top() on an empty queue below.
    if(__node == NULL || __k == 0) return ;
    Scalar dist2 = distance2(__vector, _vectors[__node->_index]);
    Answer my_ans(__node->_index, dist2);
    if(__out->size() < __k || __cmp(my_ans, __out->top())){
      __out->push(my_ans);
      if(__out->size() > __k){
        __out->pop();
      }
    }
    if(__node->_nearChild == NULL && __node->_farChild == NULL) return ;
    if(__out->size() < __k ||
       distanceCompare(dist2, -__out->top()._dist2, __node->_threshold) <= 0){
      query(__vector, __k, __cmp, __node->_nearChild, __out);
    }
    if(__out->size() < __k ||
       distanceCompare(dist2,  __out->top()._dist2, __node->_threshold) >= 0){
      query(__vector, __k, __cmp, __node->_farChild, __out);
    }
  }

  ///////////////// **# clear(), dup() #** /////////////////

  // Post-order deletion of the subtree rooted at __root (NULL is a no-op).
  template<class Vector, class Scalar>
  inline void VP_Tree<Vector, Scalar>::clear(Node* __root){
    if(__root == NULL) return ;
    clear(__root->_nearChild);
    clear(__root->_farChild);
    delete __root;
  }

  // Deep copy of the subtree rooted at __root; returns NULL for NULL.
  template<class Vector, class Scalar>
  inline typename VP_Tree<Vector, Scalar>::Node*
  VP_Tree<Vector, Scalar>::dup(Node* __root){
    if(__root == NULL) return NULL;
    Node* ret = new Node(__root->_index);
    ret->_threshold = __root->_threshold;
    ret->_nearChild = dup(__root->_nearChild);
    ret->_farChild  = dup(__root->_farChild );
    return ret;
  }

  ///////// **# construre/destructure/copy oper #** ////////

  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::VP_Tree():
  _root(NULL), _vectors(0), _dimension(0), _needRebuild(false){
    reset(0);
  }

  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::VP_Tree(VP_Tree const& __tree2):
  _root(dup(__tree2._root)),
  _vectors(__tree2._vectors),
  _dimension(__tree2._dimension),
  _needRebuild(__tree2._needRebuild){
  }

  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::VP_Tree(size_t __dimension):
  _root(NULL), _vectors(0), _dimension(0), _needRebuild(false){
    reset(__dimension);
  }

  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::~VP_Tree(){
    clear(_root);
  }

  // Deep-copies __tree2.  Self-assignment is a no-op (reset() would otherwise
  // wipe the source before it is copied).
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>&
  VP_Tree<Vector, Scalar>::operator=(VP_Tree const& __tree2){
    if(this == &__tree2) return *this;
    reset(__tree2._dimension);
    _vectors     = __tree2._vectors;
    _root        = dup(__tree2._root);
    _needRebuild = __tree2._needRebuild;
    return *this;
  }

  ////////////////// **# insert, erase #** /////////////////

  // Appends __vector; the tree is rebuilt lazily on the next query/build.
  template<class Vector, class Scalar>
  inline void VP_Tree<Vector, Scalar>::insert(Vector const& __vector){
    _vectors.push_back(__vector);
    _needRebuild = true;
  }

  // Removes the first element equal to __vector (swap-with-last, so element
  // order is not preserved).  Returns true if one was removed.
  template<class Vector, class Scalar>
  inline bool VP_Tree<Vector, Scalar>::erase(Vector const& __vector){
    for(ssize_t i = 0, I = _vectors.size(); i < I; i++){
      if(_vectors[i] == __vector){
        if(i != I - 1){
          std::swap(_vectors[i], _vectors[I - 1]);
        }
        _needRebuild = true;
        _vectors.pop_back();
        return true;
      }
    }
    return false;
  }

  ////////////////// **# build, forceBuild #** /////////////

  // Rebuilds only if vectors changed since the last build.
  template<class Vector, class Scalar>
  inline void VP_Tree<Vector, Scalar>::build(){
    if(_needRebuild){
      forceBuild();
    }
  }

  // Unconditionally rebuilds the tree, releasing the previous one first.
  template<class Vector, class Scalar>
  inline void VP_Tree<Vector, Scalar>::forceBuild(){
    clear(_root);
    _root = NULL;
    // Signed arithmetic: an empty set yields the empty range [0, -1].
    _root = build(0, static_cast<ssize_t>(_vectors.size()) - 1);
    _needRebuild = false;
  }

  ////////////////////// **# query #** /////////////////////

  // Returns up to __nearestNumber stored vectors nearest to __vector, nearest
  // first.  If __compareWholeVector is true, equal distances are ordered by
  // the vectors' operator<.  Lazily rebuilds the tree when needed.
  template<class Vector, class Scalar>
  inline typename VP_Tree<Vector, Scalar>::Vectors
  VP_Tree<Vector, Scalar>::query(Vector const& __vector,
                                 size_t __nearestNumber,
                                 bool __compareWholeVector) const{
    // The lazy rebuild mutates the cache even though query() is const.
    const_cast<VP_Tree*>(this)->build();
    AnswerCompare cmp(&_vectors, __compareWholeVector);
    Answers answers(cmp);
    query(__vector, __nearestNumber, cmp, _root, &answers);
    // answers pops worst-first; reverse through a stack to get nearest-first.
    std::stack<Answer> rev;
    for( ; !answers.empty(); answers.pop()) rev.push(answers.top());
    Vectors ret;
    for( ; !rev.empty(); rev.pop()) ret.push_back(_vectors[rev.top()._index]);
    return ret;
  }

  /////////////////// **# clear, reset #** /////////////////

  // Drops all vectors and the tree.
  template<class Vector, class Scalar>
  inline void VP_Tree<Vector, Scalar>::clear(){
    clear(_root);
    _vectors.clear();
    _root = NULL;
    _needRebuild = false;
  }

  // Clears everything and sets the dimension (at least 1); returns it.
  template<class Vector, class Scalar>
  inline size_t VP_Tree<Vector, Scalar>::reset(size_t __dimension){
    clear();
    _dimension = std::max((size_t)1, __dimension);
    return _dimension;
  }
}  // namespace meow