#include "KD_Tree.h"

#include "../utility.h"
#include "../math/utility.h"

// NOTE(review): the names of the four system headers were not readable in the
// reviewed copy of this file; the set below covers every standard name used
// here (size_t/ssize_t, std::sort/std::swap, the answer queue, std::vector).
// Confirm against the repository version.
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>

// NOTE(review): the template parameter lists are assumed to be
// <class Vector, class Scalar>, matching the type names used throughout this
// file; confirm the order against the class declaration in KD_Tree.h.

namespace meow{

////////////////////////////////////////////////////////////////////
//                         **# Node #**                           //
////////////////////////////////////////////////////////////////////

// Stores one point together with the indices of its two children
// (both indices refer to positions inside `_nodes`).
template<class Vector, class Scalar>
inline
KD_Tree<Vector, Scalar>::Node::Node(Vector  __vector,
                                    ssize_t __lChild,
                                    ssize_t __rChild):
_vector(__vector), _lChild(__lChild), _rChild(__rChild){
}

////////////////////////////////////////////////////////////////////
//                        **# Sorter #**                          //
////////////////////////////////////////////////////////////////////

// Comparator over node indices: orders by coordinate `__cmp`, falling back
// to the whole-vector `operator<` when that coordinate is equal.
template<class Vector, class Scalar>
inline
KD_Tree<Vector, Scalar>::Sorter::Sorter(Nodes const* __nodes, size_t __cmp):
_nodes(__nodes), _cmp(__cmp){
}

template<class Vector, class Scalar>
inline bool
KD_Tree<Vector, Scalar>::Sorter::operator()(size_t const& __a,
                                            size_t const& __b) const{
  Vector const& lhs = (*_nodes)[__a]._vector;
  Vector const& rhs = (*_nodes)[__b]._vector;
  if(lhs[_cmp] != rhs[_cmp]){
    return (lhs[_cmp] < rhs[_cmp]);
  }
  return (lhs < rhs);
}

////////////////////////////////////////////////////////////////////
//              **# Answer / Answer's Compare class #**           //
////////////////////////////////////////////////////////////////////

// A candidate result: node index plus its squared distance to the query.
template<class Vector, class Scalar>
inline
KD_Tree<Vector, Scalar>::Answer::Answer(ssize_t __index, Scalar __dist2):
_index(__index), _dist2(__dist2){
}

template<class Vector, class Scalar>
inline
KD_Tree<Vector, Scalar>::Answer::Answer(Answer const& __answer2):
_index(__answer2._index), _dist2(__answer2._dist2){
}

// Orders answers by squared distance.  When `__cmpValue` is true, ties on
// distance are broken with the whole-vector `operator<` of the two nodes.
template<class Vector, class Scalar>
inline
KD_Tree<Vector, Scalar>::AnswerCompare::AnswerCompare(Nodes const* __nodes,
                                                      bool __cmpValue):
_nodes(__nodes), _cmpValue(__cmpValue){
}

template<class Vector, class Scalar>
inline bool
KD_Tree<Vector, Scalar>::AnswerCompare::operator()(Answer const& __a,
                                                   Answer const& __b) const{
  if(_cmpValue == true && __a._dist2 == __b._dist2){
    return ((*_nodes)[__a._index]._vector < (*_nodes)[__b._index]._vector);
  }
  return (__a._dist2 < __b._dist2);
}

////////////////////////////////////////////////////////////////////
//                      **# distance2() #**                       //
////////////////////////////////////////////////////////////////////

// Sum of squared per-coordinate differences over the first `_dimension`
// coordinates of the two vectors.
template<class Vector, class Scalar>
inline Scalar
KD_Tree<Vector, Scalar>::distance2(Vector const& __v1,
                                   Vector const& __v2) const{
  Scalar ret(0);
  for(size_t i = 0; i < _dimension; i++){
    ret += squ(__v1[i] - __v2[i]);
  }
  return ret;
}

////////////////////////////////////////////////////////////////////
//                        **# query() #**                         //
////////////////////////////////////////////////////////////////////

// Recursive nearest-neighbour search below node `__index`.
//
//   __vector         query point
//   __nearestNumber  maximum number of answers kept in `*__out`
//   __answerCompare  ordering used to decide whether a node beats the
//                    current worst answer (`__out->top()`)
//   __index          current node, or `_NIL` for an empty subtree
//   __depth          depth of the current node; selects the split coordinate
//   __dist2Vector    per-coordinate squared distance from the query to the
//                    splitting planes crossed so far (restored before return)
//   __dist2Minimum   sum of `__dist2Vector`, i.e. a lower bound on the squared
//                    distance of anything in the current subtree
//   __out            answer container; presumably a priority queue ordered by
//                    AnswerCompare (declared in KD_Tree.h)
template<class Vector, class Scalar>
inline void
KD_Tree<Vector, Scalar>::query(Vector const&        __vector,
                               size_t               __nearestNumber,
                               AnswerCompare const& __answerCompare,
                               ssize_t              __index,
                               int                  __depth,
                               std::vector<Scalar>& __dist2Vector,
                               Scalar               __dist2Minimum,
                               Answers*             __out) const{
  if(__index == _NIL) return ;
  // With zero requested answers the queue stays empty, and calling top() on
  // it below would be undefined behaviour.
  if(__nearestNumber == 0) return ;

  size_t cmp = __depth % _dimension;
  ssize_t this_side, that_side;
  if(!(_nodes[__index]._vector[cmp] < __vector[cmp])){
    this_side = _nodes[__index]._lChild;
    that_side = _nodes[__index]._rChild;
  }else{
    this_side = _nodes[__index]._rChild;
    that_side = _nodes[__index]._lChild;
  }

  // 1. descend into the side of the split that contains the query point
  query(__vector, __nearestNumber, __answerCompare,
        this_side, __depth + 1,
        __dist2Vector, __dist2Minimum,
        __out);

  // 2. offer the current node as an answer
  Answer my_ans(__index, distance2(_nodes[__index]._vector, __vector));
  if(__out->size() < __nearestNumber ||
     __answerCompare(my_ans, __out->top())){
    __out->push(my_ans);
    if(__out->size() > __nearestNumber) __out->pop();
  }

  // 3. visit the far side only if it can still contain a good enough point
  Scalar dist2_old = __dist2Vector[cmp];
  __dist2Vector[cmp] = squ(_nodes[__index]._vector[cmp] - __vector[cmp]);
  Scalar dist2Minimum = __dist2Minimum + __dist2Vector[cmp] - dist2_old;
  if(__out->size() < __nearestNumber ||
     !(__out->top()._dist2 < dist2Minimum)){
    query(__vector, __nearestNumber, __answerCompare,
          that_side, __depth + 1,
          __dist2Vector, dist2Minimum,
          __out);
  }
  __dist2Vector[cmp] = dist2_old;
}

////////////////////////////////////////////////////////////////////
//                        **# build() #**                         //
////////////////////////////////////////////////////////////////////

// Builds the subtree for positions [__beg, __end] and returns the index of
// its root node (or `_NIL` for an empty range).
//
// `__orders` points to `_dimension + 2` index arrays:
//   [0, _dimension)   node indices sorted along each coordinate
//   [_dimension]      scratch buffer used while partitioning
//   [_dimension + 1]  per-node flag: 0 = left half (incl. median), 1 = right
template<class Vector, class Scalar>
inline ssize_t
KD_Tree<Vector, Scalar>::build(ssize_t              __beg,
                               ssize_t              __end,
                               std::vector<size_t>* __orders,
                               int                  __depth){
  if(__beg > __end) return _NIL;

  size_t tmp_order  = _dimension;
  size_t which_side = _dimension + 1;
  ssize_t mid = (__beg + __end) / 2;
  size_t cmp = __depth % _dimension;
  // Index of the median node along the split coordinate.  Slot `mid` of
  // __orders[cmp] is not rewritten by the partition loop or the recursive
  // calls below, so caching it is safe.
  size_t const pivot = __orders[cmp][mid];

  for(ssize_t i = __beg; i <= mid; i++){
    __orders[which_side][__orders[cmp][i]] = 0;
  }
  for(ssize_t i = mid + 1; i <= __end; i++){
    __orders[which_side][__orders[cmp][i]] = 1;
  }

  // Partition every other coordinate's ordering around the median while
  // keeping the relative (sorted) order inside each half.
  for(size_t i = 0; i < _dimension; i++){
    if(i == cmp) continue;
    size_t left = __beg, right = mid + 1;
    for(ssize_t j = __beg; j <= __end; j++){
      size_t ask = __orders[i][j];
      if(ask == pivot){
        __orders[tmp_order][mid] = ask;
      }else if(__orders[which_side][ask] == 1){
        __orders[tmp_order][right++] = ask;
      }else{
        __orders[tmp_order][left++] = ask;
      }
    }
    for(ssize_t j = __beg; j <= __end; j++){
      __orders[i][j] = __orders[tmp_order][j];
    }
  }

  _nodes[pivot]._lChild = build(__beg,   mid - 1, __orders, __depth + 1);
  _nodes[pivot]._rChild = build(mid + 1, __end,   __orders, __depth + 1);
  return pivot;
}

////////////////////////////////////////////////////////////////////
//                **# constructors/destructors #**                //
////////////////////////////////////////////////////////////////////

// Empty tree over 1-dimensional points.
template<class Vector, class Scalar>
inline
KD_Tree<Vector, Scalar>::KD_Tree():
_NIL(-1), _root(_NIL), _needRebuild(false), _dimension(1){
}

// Empty tree over `__dimension`-dimensional points.
// NOTE(review): a dimension of 0 makes `__depth % _dimension` in build() and
// query() divide by zero; callers must pass a positive dimension.
template<class Vector, class Scalar>
inline
KD_Tree<Vector, Scalar>::KD_Tree(size_t __dimension):
_NIL(-1), _root(_NIL), _needRebuild(false), _dimension(__dimension){
}

template<class Vector, class Scalar>
inline
KD_Tree<Vector, Scalar>::~KD_Tree(){
}

////////////////////////////////////////////////////////////////////
//                  **# insert, erase, build #**                  //
////////////////////////////////////////////////////////////////////

// Appends a point; the tree structure is rebuilt lazily on the next build().
template<class Vector, class Scalar>
inline void
KD_Tree<Vector, Scalar>::insert(Vector const& __vector){
  _nodes.push_back(Node(__vector, _NIL, _NIL));
  _needRebuild = true;
}

// Removes the first stored point equal to `__vector`.
// Returns true if a point was removed, false if none matched.
// Requires `Vector` to provide `operator==`.
template<class Vector, class Scalar>
inline bool
KD_Tree<Vector, Scalar>::erase(Vector const& __vector){
  for(size_t i = 0, I = _nodes.size(); i < I; i++){
    if(_nodes[i]._vector == __vector){
      if(i != I - 1){
        std::swap(_nodes[i], _nodes[I - 1]);
      }
      // Drop the victim, which now sits at the back.  Child indices are stale
      // after this, hence the forced rebuild before the next query.
      _nodes.pop_back();
      _needRebuild = true;
      return true;
    }
  }
  return false;
}

// Rebuilds the tree only if points were inserted or erased since the last
// build.
template<class Vector, class Scalar>
inline void
KD_Tree<Vector, Scalar>::build(){
  if(_needRebuild){
    forceBuild();
  }
}

// Unconditionally rebuilds the tree from the points currently stored.
template<class Vector, class Scalar>
inline void
KD_Tree<Vector, Scalar>::forceBuild(){
  // `_dimension` sorted index arrays plus the two scratch arrays used by the
  // recursive build(); owned by a vector so nothing leaks if sorting throws.
  std::vector<std::vector<size_t> > orders(
      _dimension + 2, std::vector<size_t>(_nodes.size()));
  for(size_t j = 0; j < _dimension; j++){
    for(size_t i = 0, I = _nodes.size(); i < I; i++){
      orders[j][i] = i;
    }
    std::sort(orders[j].begin(), orders[j].end(), Sorter(&_nodes, j));
  }
  _root = build(0, (ssize_t)_nodes.size() - 1, &orders[0], 0);
  _needRebuild = false;
}

////////////////////////////////////////////////////////////////////
//                         **# query #**                          //
////////////////////////////////////////////////////////////////////

// Returns up to `__nearestNumber` stored points closest to `__vector`,
// nearest first.  When `__compareWholeVector` is true, points at the same
// distance are ordered by the vectors' own `operator<`.
template<class Vector, class Scalar>
inline typename KD_Tree<Vector, Scalar>::Vectors
KD_Tree<Vector, Scalar>::query(Vector const& __vector,
                               size_t        __nearestNumber,
                               bool          __compareWholeVector) const{
  // Lazy rebuild from a const method.
  // NOTE(review): this mutates the object through const_cast; it is only
  // well-defined when the KD_Tree object itself was not declared const.
  const_cast<KD_Tree*>(this)->build();

  // Nothing requested: return early instead of letting the recursive search
  // call top() on an empty answer queue.
  if(__nearestNumber == 0) return Vectors();

  AnswerCompare answer_compare(&_nodes, __compareWholeVector);
  Answers answer_set(answer_compare);
  std::vector<Scalar> tmp(_dimension, 0);
  query(__vector, __nearestNumber,
        answer_compare,
        _root, 0,
        tmp, Scalar(0),
        &answer_set);

  // top() yields the worst kept answer first, so fill the result backwards.
  Vectors ret(answer_set.size());
  for(ssize_t i = (ssize_t)answer_set.size() - 1; i >= 0; i--){
    ret[i] = _nodes[answer_set.top()._index]._vector;
    answer_set.pop();
  }
  return ret;
}

////////////////////////////////////////////////////////////////////
//                      **# clear, reset #**                      //
////////////////////////////////////////////////////////////////////

// Removes every point and resets the tree to the empty state.
template<class Vector, class Scalar>
inline void
KD_Tree<Vector, Scalar>::clear(){
  _root = _NIL;
  _nodes.clear();
  _needRebuild = false;
}

// Clears the tree and switches it to `__dimension`-dimensional points.
template<class Vector, class Scalar>
inline void
KD_Tree<Vector, Scalar>::reset(size_t __dimension){
  clear();
  _dimension = __dimension;
}

}