aboutsummaryrefslogblamecommitdiffstats
path: root/meowpp/dsa/VP_Tree.hpp
blob: ba97ad7f8f471d03e040ff014e1abe78ce7356e7 (plain) (tree)



































































                                                                              









                                                                         




























                                                                              





































































































































































































                                                                              

#include <cstdlib>
#include <algorithm>
#include "../utility.h"

namespace meow{
  ///////////////////// **# Node #** ///////////////////////
  /// Construct a tree node that refers to the vector stored at `__index`.
  ///
  /// Both children start out as NULL.  `_threshold` is given a defined value
  /// here because build() only assigns it for nodes that have children, while
  /// dup() and print() read it for every node (leaves included).
  template<class Vector, class Scalar>
  inline
  VP_Tree<Vector, Scalar>::Node::Node(size_t __index):
  _index(__index), _nearChild(NULL), _farChild(NULL){
    // Assigned in the body (not the initializer list) so the statement is
    // independent of the member declaration order.
    _threshold = Scalar(0);
  }
  ///////////////////// **# Answer #** /////////////////////
  /// Build an answer record from a vector index and the squared distance
  /// associated with it.
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::Answer::Answer(size_t        __index,
                                                 Scalar const& __dist2)
    : _index(__index)
    , _dist2(__dist2)
  {
  }
  /// Copy constructor: duplicates the index and the squared distance of
  /// `__answer2`.
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::Answer::Answer(Answer const& __answer2)
    : _index(__answer2._index)
    , _dist2(__answer2._dist2)
  {
  }
  /// Create a comparator over answers.
  ///
  /// @param __vectors   container that the answers' indices refer to
  /// @param __cmpValue  when true, ties on distance are broken by comparing
  ///                    the referenced vectors themselves (see operator())
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::AnswerCompare::AnswerCompare(
      Vectors const* __vectors,
      bool           __cmpValue)
    : _vectors(__vectors)
    , _cmpValue(__cmpValue)
  {
  }
  /// Strict "less than" between two answers.
  ///
  /// The squared distance is the primary key.  When both distances are
  /// equivalent, the answers compare equal unless `_cmpValue` is set, in
  /// which case the referenced vectors are compared with their own
  /// operator<.
  template<class Vector, class Scalar>
  inline bool
  VP_Tree<Vector, Scalar>::AnswerCompare::operator()(Answer const& __a,
                                                     Answer const& __b) const{
    if(__a._dist2 < __b._dist2){
      return true;
    }
    if(__b._dist2 < __a._dist2){
      return false;
    }
    // Distances tie: only look at the vectors when asked to.
    if(!_cmpValue){
      return false;
    }
    if((*_vectors)[__a._index] < (*_vectors)[__b._index]){
      return true;
    }
    return false;
  }
  //////// **# distance2, distanceCompare, split #** ///////
  /// Squared Euclidean distance between `__v1` and `__v2`, accumulated over
  /// the first `_dimension` coordinates.
  template<class Vector, class Scalar>
  inline Scalar
  VP_Tree<Vector, Scalar>::distance2(Vector const& __v1,
                                     Vector const& __v2) const{
    Scalar sum(0);
    size_t axis = 0;
    while(axis < _dimension){
      sum += squ(__v1[axis] - __v2[axis]);
      axis++;
    }
    return sum;
  }
  /// Compare sqrt(__a2) +/- sqrt(|__b2|) against sqrt(__c2) using squared
  /// quantities only (no square root is ever taken).
  ///
  /// The sign of `__b2` selects the operation: a non-negative `__b2` means
  /// "plus", a negative `__b2` means "minus sqrt(-__b2)".
  ///
  /// @return -1, 0 or 1 when the left-hand side is respectively smaller
  ///         than, equal to, or greater than sqrt(__c2).
  /// NOTE(review): presumably `__a2` and `__c2` are non-negative squared
  /// distances -- confirm against callers.
  template<class Vector, class Scalar>
  inline int
  VP_Tree<Vector, Scalar>::distanceCompare(Scalar const& __a2,
                                           Scalar const& __b2,
                                           Scalar const& __c2) const{
    // a - b ? c   <=>   -(c + b ? a): reduce the "minus" case to "plus".
    if(__b2 < 0){
      return -distanceCompare(__c2, -__b2, __a2);
    }
    // (a + b)^2 ? c^2   <=>   2ab ? c^2 - a^2 - b^2
    Scalar slack(__c2 - __a2 - __b2);
    if(slack < Scalar(0)){
      return 1;   // right side negative while 2ab >= 0
    }
    // Both sides are non-negative now, so squaring keeps the ordering.
    Scalar lhs(Scalar(4) * __a2 * __b2);
    Scalar rhs(squ(slack));
    if(lhs < rhs){
      return -1;
    }
    if(rhs < lhs){
      return 1;
    }
    return 0;
  }
  /// Randomised selection (quickselect style) on `_vectors[__first..__last]`
  /// (both bounds inclusive), keyed by squared distance to `__center`.
  ///
  /// `_vectors` is reordered in place, together with a local array of the
  /// precomputed distances.  On return the value handed back is the squared
  /// distance stored at the position the search converged on, i.e. the
  /// element of 0-based rank `__order` inside the original range.
  ///
  /// @param __first   first index of the range (inclusive)
  /// @param __last    last index of the range (inclusive)
  /// @param __order   0-based rank, relative to `__first`, of the wanted element
  /// @param __center  point the distances are measured from
  /// @return squared distance of the selected element to `__center`
  template<class Vector, class Scalar>
  inline Scalar
  VP_Tree<Vector, Scalar>::split(ssize_t __first, ssize_t __last,
                                 size_t __order, Vector const& __center){
    // first0 is the offset between indices into _vectors and into dist2.
    // (last0 and order0 are not used below.)
    ssize_t first0 = __first;
    ssize_t last0  = __last;
    ssize_t order0 = __order;
    // Cache every distance once; dist2[i - first0] belongs to _vectors[i]
    // and is swapped in lock-step with it.
    std::vector<Scalar> dist2(__last - __first + 1);
    for(ssize_t i = __first; i <= __last; i++){
      dist2[i - first0] = distance2(_vectors[i], __center);
    }
    while(__first < __last){
      // Pick a random pivot distance inside the current window.
      size_t threshold_index = __first + rand() % (__last - __first + 1);
      Scalar threshold(dist2[threshold_index - first0]);
      // Partition: [large_first, __last] collects elements strictly greater
      // than the pivot; `i` scans from the left for such elements to swap
      // towards the right end.
      size_t large_first = __last + 1;
      for(size_t i = __first; __first <= large_first - 1; large_first--){
        // Already on the correct (large) side: just extend the large part.
        if(threshold < dist2[large_first - 1 - first0]) continue;
        // Skip elements on the left that are not greater than the pivot.
        while(i < large_first - 1 && !(threshold < dist2[i - first0])) i++;
        if(i < large_first - 1){
          std::swap(dist2   [large_first - 1 - first0], dist2   [i - first0]);
          std::swap(_vectors[large_first - 1         ], _vectors[i         ]);
          i++;
        }else{
          // The scans met: everything left of large_first is <= pivot.
          break;
        }
      }
      if(large_first == __last + 1){
        // Nothing is greater than the pivot, so the pivot is a maximum of
        // the window: park it at the end and shrink the window by one.
        std::swap(dist2   [threshold_index-first0], dist2   [__last-first0]);
        std::swap(_vectors[threshold_index       ], _vectors[__last       ]);
        if(__order == __last - __first){
          // The wanted rank is exactly that last slot.
          __first = __last;
          break;
        }
        __last--;
      }else{
        // Continue in whichever side contains the wanted rank.
        if(__order < large_first - __first){
          __last = large_first - 1;
        }else{
          __order -= large_first - __first;
          __first = large_first;
        }
      }
    }
    return dist2[__first - first0];
  }
  ////////////////////// **# build() #** ///////////////////
  /// Recursively build the subtree covering `_vectors[__first..__last]`
  /// (inclusive bounds).
  ///
  /// A randomly chosen element is moved to `__first` and becomes the vantage
  /// point of the new node.  The remaining elements are arranged by split()
  /// around their middle position; the lower half becomes the near child and
  /// the upper half the far child.  The distance returned by split() is
  /// stored as the node's threshold.
  ///
  /// @return the new subtree root, or NULL for an empty range
  template<class Vector, class Scalar>
  inline typename VP_Tree<Vector, Scalar>::Node*
  VP_Tree<Vector, Scalar>::build(ssize_t __first, ssize_t __last){
    if(__last < __first){
      return NULL;
    }
    Node* node = new Node(__first);
    if(__first == __last){
      // Single element: a leaf without threshold or children.
      return node;
    }
    // Random vantage point, moved to the front of the range.
    ssize_t const pick = __first + rand() % (__last - __first + 1);
    std::swap(_vectors[__first], _vectors[pick]);
    // Everything after the vantage point is distributed to the children.
    ssize_t const restFirst = __first + 1;
    ssize_t const mid       = (restFirst + __last + 1) / 2;
    node->_threshold = split(restFirst, __last, mid - restFirst,
                             _vectors[__first]);
    node->_nearChild = build(restFirst, mid - 1);
    node->_farChild  = build(mid, __last);
    return node;
  }
  ////////////////////// **# query() #** ///////////////////
  /// Recursive k-nearest search in the subtree rooted at `__node`.
  ///
  /// `__out` holds the best candidates found so far; its top() is treated as
  /// the current worst one and is replaced whenever a better answer shows
  /// up.  A child is visited only if fewer than `__k` answers were collected
  /// or distanceCompare() cannot rule it out using the node's threshold.
  ///
  /// @param __vector  query point
  /// @param __k       maximum number of answers to keep
  /// @param __cmp     ordering used to compare answers
  /// @param __node    current subtree root (may be NULL)
  /// @param __out     answer container, updated in place
  template<class Vector, class Scalar>
  inline void
  VP_Tree<Vector, Scalar>::query(Vector const&        __vector,
                                 size_t               __k,
                                 AnswerCompare const& __cmp,
                                 Node const*          __node,
                                 Answers*             __out) const{
    // Nothing to visit, or nothing requested.  The __k == 0 guard also keeps
    // the code below from calling top() on an empty answer container.
    if(__node == NULL || __k == 0) return ;
    Scalar dist2 = distance2(__vector, _vectors[__node->_index]);
    Answer my_ans(__node->_index, dist2);
    // Keep this node's vector if there is room or it beats the current worst.
    if(__out->size() < __k || __cmp(my_ans, __out->top())){
      __out->push(my_ans);
      if(__out->size() > __k){
        __out->pop();
      }
    }
    // Leaves carry no meaningful threshold; stop here.
    if(__node->_nearChild == NULL && __node->_farChild == NULL) return ;
    // Near side: sqrt(dist2) - sqrt(worst) <= sqrt(threshold)
    if(__out->size() < __k || distanceCompare(dist2, -__out->top()._dist2,
                                              __node->_threshold) <= 0){
      query(__vector, __k, __cmp, __node->_nearChild, __out);
    }
    // Far side: sqrt(dist2) + sqrt(worst) >= sqrt(threshold)
    if(__out->size() < __k || distanceCompare(dist2,  __out->top()._dist2,
                                              __node->_threshold) >= 0){
      query(__vector, __k, __cmp, __node->_farChild, __out);
    }
  }
  ///////////////// **# clear(), dup() #** /////////////////
  /// Free the subtree rooted at `__root`: children first, then the node
  /// itself.  A NULL pointer is accepted and ignored.
  template<class Vector, class Scalar>
  inline void
  VP_Tree<Vector, Scalar>::clear(Node* __root){
    if(__root != NULL){
      clear(__root->_nearChild);
      clear(__root->_farChild);
      delete __root;
    }
  }
  /// Deep-copy the subtree rooted at `__root`.
  ///
  /// Index and threshold are copied from each node and both children are
  /// duplicated recursively.
  ///
  /// @return root of the newly allocated copy, or NULL when `__root` is NULL
  template<class Vector, class Scalar>
  inline typename VP_Tree<Vector, Scalar>::Node*
  VP_Tree<Vector, Scalar>::dup(Node* __root){
    // An empty subtree copies to an empty subtree; the function returns a
    // pointer, so a value must be returned here.
    if(__root == NULL) return NULL;
    Node* ret = new Node(__root->_index);
    ret->_threshold = __root->_threshold;
    ret->_nearChild = dup(__root->_nearChild);
    ret->_farChild  = dup(__root->_farChild );
    return ret;
  }
  /////////////////////// **# print #** ////////////////////
  /// Debug helper: dump the subtree rooted at `__node` to stdout, one node
  /// per line, indented by two spaces per `depth` level.
  ///
  /// Each line shows whether the node is a near ('N') or far ('F') child,
  /// the first two coordinates of its vector and its threshold.  When a
  /// parent is given, the parent's first two coordinates and the squared
  /// distance between parent and node are appended.
  ///
  /// NOTE(review): the format strings use %lld, so this presumably expects
  /// coordinates and Scalar to be `long long` -- confirm before using it
  /// with other types.
  template<class Vector, class Scalar>
  inline void
  VP_Tree<Vector, Scalar>::print(Node* __node, int depth,
                                 Node* __parent, bool __near) const{
    if(__node == NULL){
      return ;
    }
    Vector const& me = _vectors[__node->_index];
    printf("%*s%c)Me<%lld,%lld>, rad2 = %7lld",
           depth * 2, "",
           __near ? 'N' : 'F',
           me[0], me[1],
           __node->_threshold);
    if(__parent == NULL){
      printf("\n");
    }else{
      Vector const& up = _vectors[__parent->_index];
      printf("   ---<%lld,%lld>:: %lld\n",
             up[0],
             up[1],
             distance2(up, me));
    }
    int const below = depth + 1;
    print(__node->_nearChild, below, __node, true );
    print(__node->_farChild , below, __node, false);
  }
  ////// **# constructor/destructor/copy operator #** //////
  /// Default constructor: an empty tree.  The dimension is set through
  /// reset(0), which stores max(1, 0) = 1.
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::VP_Tree()
    : _root(NULL)
    , _vectors(0)
    , _dimension(0)
    , _needRebuild(false)
  {
    reset(0);
  }
  /// Copy constructor: deep-copies the node structure of `__tree2` with
  /// dup() and copies its vectors, dimension and rebuild flag.
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::VP_Tree(VP_Tree<Vector, Scalar> const& __tree2)
    : _root(dup(__tree2._root))
    , _vectors(__tree2._vectors)
    , _dimension(__tree2._dimension)
    , _needRebuild(__tree2._needRebuild)
  {
  }
  /// Construct an empty tree for vectors of `__dimension` coordinates.
  /// The value is applied through reset(), which raises it to at least 1.
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::VP_Tree(size_t __dimension)
    : _root(NULL)
    , _vectors(0)
    , _dimension(0)
    , _needRebuild(false)
  {
    reset(__dimension);
  }
  /// Destructor: releases every node reachable from the root.
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>::~VP_Tree()
  {
    clear(_root);
  }
  /// Copy assignment: replace this tree's contents with a deep copy of
  /// `__tree2` (dimension, vectors, node structure and rebuild flag).
  ///
  /// @return reference to this tree
  template<class Vector, class Scalar>
  inline VP_Tree<Vector, Scalar>&
  VP_Tree<Vector, Scalar>::operator=(VP_Tree const& __tree2){
    // reset() destroys the current nodes and vectors, so self-assignment
    // must be skipped or the source would be wiped before it is copied.
    if(this != &__tree2){
      reset(__tree2._dimension);
      _vectors     = __tree2._vectors;
      _root        = dup(__tree2._root);
      _needRebuild = __tree2._needRebuild;
    }
    return *this;
  }
  ////////////////// **# insert, erase #** /////////////////
  /// Append `__vector` to the stored vectors.  The tree itself is not
  /// updated here; it is only flagged so the next build() reconstructs it.
  template<class Vector, class Scalar>
  inline void VP_Tree<Vector, Scalar>::insert(Vector const& __vector)
  {
    _vectors.push_back(__vector);
    _needRebuild = true;
  }
  /// Remove the first stored vector that compares equal to `__vector`.
  ///
  /// The match is swapped with the last element and popped, so the order of
  /// the stored vectors is not preserved.  The tree is flagged for rebuild.
  ///
  /// @return true when a vector was removed, false when none matched
  template<class Vector, class Scalar>
  inline bool
  VP_Tree<Vector, Scalar>::erase(Vector const& __vector){
    ssize_t const count = _vectors.size();
    ssize_t const tail  = count - 1;
    ssize_t pos = 0;
    while(pos < count){
      if(_vectors[pos] == __vector){
        if(pos != tail){
          std::swap(_vectors[pos], _vectors[tail]);
        }
        _needRebuild = true;
        _vectors.pop_back();
        return true;
      }
      pos++;
    }
    return false;
  }
  ////////////////// **# build, forceBuild #** /////////////
  /// Rebuild the tree only if it was flagged as out of date.
  template<class Vector, class Scalar>
  inline void VP_Tree<Vector, Scalar>::build()
  {
    if(!_needRebuild){
      return ;
    }
    forceBuild();
  }
  /// Unconditionally rebuild the tree from the stored vectors and clear the
  /// rebuild flag.
  template<class Vector, class Scalar>
  inline void
  VP_Tree<Vector, Scalar>::forceBuild(){
    // Release the previous tree first; otherwise every rebuild would leak
    // all nodes of the old one.
    clear(_root);
    _root = NULL;
    // Signed arithmetic: an empty container yields last == -1, which
    // build() treats as an empty range.
    _root = build(0, (ssize_t)_vectors.size() - 1);
    _needRebuild = false;
  }
  ////////////////////// **# query #** /////////////////////
  /// Return up to `__nearestNumber` stored vectors closest to `__vector`.
  ///
  /// The tree is (re)built first if needed, which is why constness is cast
  /// away for that call.  Answers are taken out of the answer container and
  /// reversed through a stack before being copied to the result.
  /// NOTE(review): presumably `Answers` is a priority queue whose top() is
  /// the farthest answer, making the result nearest-first -- confirm
  /// against the class declaration.
  ///
  /// @param __vector              query point
  /// @param __nearestNumber       maximum number of vectors to return
  /// @param __compareWholeVector  break distance ties by comparing vectors
  template<class Vector, class Scalar>
  inline typename VP_Tree<Vector, Scalar>::Vectors
  VP_Tree<Vector, Scalar>::query(Vector const& __vector,
                                 size_t        __nearestNumber,
                                 bool          __compareWholeVector) const{
    const_cast<VP_Tree*>(this)->build();

    AnswerCompare order(&_vectors, __compareWholeVector);
    Answers found(order);
    query(__vector, __nearestNumber, order, _root, &found);

    // Reverse the pop order of the answer container.
    std::stack<Answer> flipped;
    while(!found.empty()){
      flipped.push(found.top());
      found.pop();
    }

    Vectors result;
    while(!flipped.empty()){
      result.push_back(_vectors[flipped.top()._index]);
      flipped.pop();
    }
    return result;
  }
  /////////////////// **# clear, reset #** /////////////////
  /// Drop all nodes and all stored vectors, leaving an empty tree that does
  /// not need rebuilding.  The dimension is left untouched.
  template<class Vector, class Scalar>
  void VP_Tree<Vector, Scalar>::clear()
  {
    clear(_root);
    _root = NULL;
    _vectors.clear();
    _needRebuild = false;
  }
  /// Empty the tree and set a new dimension.
  ///
  /// @param __dimension  requested dimension; values below 1 are raised to 1
  /// @return the dimension actually stored
  template<class Vector, class Scalar>
  size_t VP_Tree<Vector, Scalar>::reset(size_t __dimension)
  {
    clear();
    size_t const lowest = 1;
    _dimension = std::max(lowest, __dimension);
    return _dimension;
  }
  
  /////////////////////// **# print #** ////////////////////
  /// Debug helper: print the number of stored vectors, the dimension and
  /// then the whole tree (starting at depth 1) to stdout.
  template<class Vector, class Scalar>
  void inline
  VP_Tree<Vector, Scalar>::print() const{
    // %lu expects unsigned long; size_t is a different type on some
    // platforms, so convert explicitly to match the format string.
    printf("\nsize = %lu, dimension = %lu\n",
           static_cast<unsigned long>(_vectors.size()),
           static_cast<unsigned long>(_dimension));
    print(_root, 1);
    printf("\n\n");
  }
};