#include <cstdlib>
#include <algorithm>
#include "../utility.h"
namespace meow{
///////////////////// **# Node #** ///////////////////////
/**
 * Node constructor: a tree node that refers to element __index of _vectors
 * and has no children yet.
 *
 * _threshold is value-initialised here because build() only assigns it to
 * inner nodes, while dup() copies it and print() reads it for every node;
 * reading an uninitialised scalar is undefined behaviour.
 */
template<class Vector, class Scalar>
inline
VP_Tree<Vector, Scalar>::Node::Node(size_t __index):
_index(__index), _nearChild(NULL), _farChild(NULL){
  _threshold = Scalar();
}
///////////////////// **# Answer #** /////////////////////
// One search result: the index of a stored vector and its squared distance
// to the query point.
template<class Vector, class Scalar>
inline
VP_Tree<Vector, Scalar>::Answer::Answer(size_t        __index,
                                        Scalar const& __dist2)
  : _index(__index)
  , _dist2(__dist2)
{
}
// Copy constructor: member-wise copy of index and squared distance.
template<class Vector, class Scalar>
inline
VP_Tree<Vector, Scalar>::Answer::Answer(Answer const& __answer2)
  : _index(__answer2._index)
  , _dist2(__answer2._dist2)
{
}
// Comparator state: the vector storage (used only for tie-breaking) and
// whether ties on distance should be broken by comparing the vectors.
template<class Vector, class Scalar>
inline
VP_Tree<Vector, Scalar>::AnswerCompare::AnswerCompare(Vectors const* __vectors,
                                                      bool           __cmpValue)
  : _vectors(__vectors)
  , _cmpValue(__cmpValue)
{
}
// Strict ordering of answers: smaller squared distance first; on equal
// distance, optionally fall back to Vector's operator<.
template<class Vector, class Scalar>
inline bool
VP_Tree<Vector, Scalar>::AnswerCompare::operator()(Answer const& __a,
                                                   Answer const& __b) const{
  bool const aCloser = (__a._dist2 < __b._dist2);
  if(aCloser) return true;
  bool const bCloser = (__b._dist2 < __a._dist2);
  if(bCloser) return false;
  // Distances tie: only consult the vectors when asked to.
  return _cmpValue ? ((*_vectors)[__a._index] < (*_vectors)[__b._index])
                   : false;
}
//////// **# distance2, distanceCompare, split #** ///////
// Squared Euclidean distance between __v1 and __v2 over the first
// _dimension coordinates.
template<class Vector, class Scalar>
inline Scalar
VP_Tree<Vector, Scalar>::distance2(Vector const& __v1,
                                   Vector const& __v2) const{
  Scalar sum(0);
  size_t axis = 0;
  while(axis < _dimension){
    sum += squ(__v1[axis] - __v2[axis]);
    ++axis;
  }
  return sum;
}
/**
 * Compare sqrt(__a2) + sqrt(__b2) against sqrt(__c2) without taking roots.
 * A negative __b2 means subtraction: sqrt(__a2) - sqrt(-__b2).
 *
 * @return -1, 0 or 1 as the left-hand side is less than, equal to or
 *         greater than sqrt(__c2).
 */
template<class Vector, class Scalar>
inline int
VP_Tree<Vector, Scalar>::distanceCompare(Scalar const& __a2,
                                         Scalar const& __b2,
                                         Scalar const& __c2) const{
  // sqrt(a) - sqrt(b) vs sqrt(c)  <=>  -(sqrt(c) + sqrt(b) vs sqrt(a))
  if(__b2 < 0) return -distanceCompare(__c2, -__b2, __a2);
  // Squaring both sides: a + b + 2*sqrt(ab) vs c, i.e. 2*sqrt(ab) vs c-a-b.
  Scalar const slack(__c2 - __a2 - __b2);
  if(slack < Scalar(0)) return 1;  // left side is non-negative, right is not
  // Both sides non-negative now, so square once more: 4ab vs (c-a-b)^2.
  Scalar const lhs(Scalar(4) * __a2 * __b2);
  Scalar const rhs(squ(slack));
  if(rhs < lhs) return 1;
  if(lhs < rhs) return -1;
  return 0;
}
/**
 * Quickselect on squared distance to __center.
 *
 * Reorders _vectors[__first .. __last] (inclusive) so that the element at
 * position __first + __order has the __order-th smallest (0-based) squared
 * distance to __center. Every element before it is no farther, and every
 * element after it is no nearer.
 *
 * A three-way (less / equal / greater) partition is used, so runs of equal
 * distances are settled in a single pass instead of degrading to O(n^2).
 *
 * Precondition: __center must not refer to an element inside
 * [__first, __last]; those elements are swapped around.
 *
 * @return the squared distance of the element placed at __first + __order.
 */
template<class Vector, class Scalar>
inline Scalar
VP_Tree<Vector, Scalar>::split(ssize_t __first, ssize_t __last,
                               size_t __order, Vector const& __center){
  ssize_t const base = __first;
  std::vector<Scalar> dist2(__last - __first + 1);
  for(ssize_t i = __first; i <= __last; i++){
    dist2[i - base] = distance2(_vectors[i], __center);
  }
  ssize_t const target = __first + static_cast<ssize_t>(__order);
  ssize_t lo = __first;
  ssize_t hi = __last;
  while(lo < hi){
    // Copy (not reference): the swaps below move the pivot's slot.
    Scalar const pivot(dist2[lo + rand() % (hi - lo + 1) - base]);
    // Invariant: [lo, lt) < pivot, [lt, i) == pivot, (gt, hi] > pivot,
    //            [i, gt] not yet classified.
    ssize_t lt = lo;
    ssize_t i  = lo;
    ssize_t gt = hi;
    while(i <= gt){
      if(dist2[i - base] < pivot){
        if(i != lt){
          std::swap(dist2   [i - base], dist2   [lt - base]);
          std::swap(_vectors[i       ], _vectors[lt       ]);
        }
        lt++;
        i++;
      }else if(pivot < dist2[i - base]){
        if(i != gt){
          std::swap(dist2   [i - base], dist2   [gt - base]);
          std::swap(_vectors[i       ], _vectors[gt       ]);
        }
        gt--;
      }else{
        i++;
      }
    }
    // The pivot value occurs at least once, so lt <= gt and each branch
    // below strictly shrinks [lo, hi] while keeping target inside it.
    if(target < lt){
      hi = lt - 1;
    }else if(gt < target){
      lo = gt + 1;
    }else{
      return pivot;  // target lies in the block of values equal to pivot
    }
  }
  return dist2[target - base];
}
////////////////////// **# build() #** ///////////////////
/**
 * Recursively build the subtree over _vectors[__first .. __last].
 *
 * A random element becomes the vantage point (moved to __first). The rest
 * is split at `mid`: [__first + 1, mid - 1] go to the near child and
 * [mid, __last] to the far child. The node's _threshold is the squared
 * distance of the element at `mid` from the vantage point.
 *
 * @return the new subtree root, or NULL for an empty range.
 */
template<class Vector, class Scalar>
inline typename VP_Tree<Vector, Scalar>::Node*
VP_Tree<Vector, Scalar>::build(ssize_t __first, ssize_t __last){
  if(__last < __first) return NULL;
  Node* node = new Node(__first);
  if(__first == __last) return node;  // a single element is a leaf
  ssize_t const pick = __first + rand() % (__last - __first + 1);
  std::swap(_vectors[__first], _vectors[pick]);
  ssize_t const childFirst = __first + 1;
  ssize_t const mid = (childFirst + __last + 1) / 2;
  node->_threshold = split(childFirst, __last, mid - childFirst,
                           _vectors[__first]);
  node->_nearChild = build(childFirst, mid - 1);
  node->_farChild  = build(mid, __last);
  return node;
}
////////////////////// **# query() #** ///////////////////
/**
 * Recursive k-nearest-neighbour search below __node.
 *
 * __out holds at most __k candidates. Its top() is treated as the worst
 * candidate kept so far; this presumably relies on Answers being a max-heap
 * ordered by AnswerCompare (confirm in the header). A child subtree is
 * skipped only when the triangle inequality shows it cannot hold a point
 * within the current worst candidate's distance.
 *
 * @param __vector query point
 * @param __k      number of neighbours wanted (0 yields nothing)
 * @param __cmp    ordering used to decide whether a candidate beats top()
 * @param __node   subtree root, may be NULL
 * @param __out    candidate heap, updated in place
 */
template<class Vector, class Scalar>
inline void
VP_Tree<Vector, Scalar>::query(Vector const& __vector,
                               size_t __k,
                               AnswerCompare const& __cmp,
                               Node const* __node,
                               Answers* __out) const{
  // __k == 0 must stop here: the heap would stay empty and the checks
  // below would call __out->top() on it, which is undefined behaviour.
  if(__node == NULL || __k == 0) return ;
  Scalar dist2 = distance2(__vector, _vectors[__node->_index]);
  Answer my_ans(__node->_index, dist2);
  if(__out->size() < __k || __cmp(my_ans, __out->top())){
    __out->push(my_ans);
    if(__out->size() > __k){
      __out->pop();
    }
  }
  if(__node->_nearChild == NULL && __node->_farChild == NULL) return ;
  // Near child: squared distances <= _threshold from this vantage point.
  // Visit unless sqrt(dist2) - r > sqrt(_threshold), r = worst distance.
  if(__out->size() < __k || distanceCompare(dist2, -__out->top()._dist2,
                                            __node->_threshold) <= 0){
    query(__vector, __k, __cmp, __node->_nearChild, __out);
  }
  // Far child: squared distances >= _threshold from this vantage point.
  // Visit unless sqrt(dist2) + r < sqrt(_threshold).
  if(__out->size() < __k || distanceCompare(dist2, __out->top()._dist2,
                                            __node->_threshold) >= 0){
    query(__vector, __k, __cmp, __node->_farChild, __out);
  }
}
///////////////// **# clear(), dup() #** /////////////////
// Delete the subtree rooted at __root (post-order); NULL is a no-op.
template<class Vector, class Scalar>
inline void
VP_Tree<Vector, Scalar>::clear(Node* __root){
  if(__root != NULL){
    clear(__root->_nearChild);
    clear(__root->_farChild);
    delete __root;
  }
}
/**
 * Deep-copy the subtree rooted at __root.
 *
 * @return a freshly allocated copy (index, threshold and both children),
 *         or NULL when __root is NULL.
 */
template<class Vector, class Scalar>
inline typename VP_Tree<Vector, Scalar>::Node*
VP_Tree<Vector, Scalar>::dup(Node* __root){
  // The function returns Node*, so a bare `return;` is ill-formed.
  if(__root == NULL) return NULL;
  Node* ret = new Node(__root->_index);
  ret->_threshold = __root->_threshold;
  ret->_nearChild = dup(__root->_nearChild);
  ret->_farChild  = dup(__root->_farChild );
  return ret;
}
/////////////////////// **# print #** ////////////////////
/**
 * Debug dump of the subtree at __node, indented by depth.
 *
 * Prints only the first two coordinates. The %lld specifiers assume the
 * coordinate and Scalar types are long long (NOTE(review): confirm for
 * other instantiations).
 */
template<class Vector, class Scalar>
inline void
VP_Tree<Vector, Scalar>::print(Node* __node, int depth,
                               Node* __parent, bool __near) const{
  if(__node != NULL){
    Vector const& mine = _vectors[__node->_index];
    printf("%*s%c)Me<%lld,%lld>, rad2 = %7lld",
           depth * 2, "",
           __near ? 'N' : 'F',
           mine[0], mine[1],
           __node->_threshold);
    if(__parent == NULL){
      printf("\n");
    }else{
      Vector const& up = _vectors[__parent->_index];
      printf(" ---<%lld,%lld>:: %lld\n",
             up[0], up[1],
             distance2(up, mine));
    }
    print(__node->_nearChild, depth + 1, __node, true );
    print(__node->_farChild , depth + 1, __node, false);
  }
}
///////// **# constructors/destructor/copy operator #** ////////
// Default constructor: empty tree; reset(0) clamps the dimension to 1.
template<class Vector, class Scalar>
inline
VP_Tree<Vector, Scalar>::VP_Tree()
  : _root(NULL)
  , _vectors(0)
  , _dimension(0)
  , _needRebuild(false)
{
  reset(0);
}
// Copy constructor: dup() allocates a new node for every node of __tree2,
// so the two trees share no nodes.
template<class Vector, class Scalar>
inline
VP_Tree<Vector, Scalar>::VP_Tree(VP_Tree<Vector, Scalar> const& __tree2)
  : _root       (dup(__tree2._root))
  , _vectors    (__tree2._vectors)
  , _dimension  (__tree2._dimension)
  , _needRebuild(__tree2._needRebuild)
{
}
// Construct an empty tree for the given dimension (clamped to >= 1 by reset).
template<class Vector, class Scalar>
inline
VP_Tree<Vector, Scalar>::VP_Tree(size_t __dimension)
  : _root(NULL)
  , _vectors(0)
  , _dimension(0)
  , _needRebuild(false)
{
  reset(__dimension);
}
// Destructor: free every tree node; _vectors releases itself.
template<class Vector, class Scalar>
inline
VP_Tree<Vector, Scalar>::~VP_Tree(){
  this->clear(this->_root);
}
/**
 * Copy assignment: replace this tree's contents with a deep copy of __tree2.
 *
 * Self-assignment is a no-op. Without that check, reset() would clear
 * __tree2 (which is *this) before anything is copied from it.
 *
 * @return *this
 */
template<class Vector, class Scalar>
inline VP_Tree<Vector, Scalar>&
VP_Tree<Vector, Scalar>::operator=(VP_Tree const& __tree2){
  if(this == &__tree2) return *this;
  reset(__tree2._dimension);
  _vectors     = __tree2._vectors;
  _root        = dup(__tree2._root);
  _needRebuild = __tree2._needRebuild;
  return *this;
}
////////////////// **# insert, erase #** /////////////////
// Append a vector; the tree is rebuilt lazily on the next query/build.
template<class Vector, class Scalar>
inline void
VP_Tree<Vector, Scalar>::insert(Vector const& __vector){
  this->_vectors.push_back(__vector);
  this->_needRebuild = true;
}
/**
 * Remove the first stored vector equal to __vector (swap with last, then
 * pop), marking the tree for rebuild.
 *
 * @return true if a matching vector was found and removed.
 */
template<class Vector, class Scalar>
inline bool
VP_Tree<Vector, Scalar>::erase(Vector const& __vector){
  size_t const count = _vectors.size();
  for(size_t pos = 0; pos < count; ++pos){
    if(!(_vectors[pos] == __vector)) continue;
    size_t const tail = count - 1;
    if(pos != tail) std::swap(_vectors[pos], _vectors[tail]);
    _needRebuild = true;
    _vectors.pop_back();
    return true;
  }
  return false;
}
////////////////// **# build, forceBuild #** /////////////
// Rebuild only when _needRebuild is set (insert()/erase() set it).
template<class Vector, class Scalar>
inline void
VP_Tree<Vector, Scalar>::build(){
  if(!_needRebuild) return;
  forceBuild();
}
/**
 * Unconditionally rebuild the tree over all stored vectors.
 *
 * The previous tree is freed first. It used to be overwritten without being
 * deleted, which leaked every node on each rebuild.
 */
template<class Vector, class Scalar>
inline void
VP_Tree<Vector, Scalar>::forceBuild(){
  clear(_root);
  _root = NULL;  // never leave a dangling root if build() throws
  // Signed arithmetic: an empty store gives the empty range [0, -1]
  // instead of relying on size_t wrap-around.
  _root = build(0, static_cast<ssize_t>(_vectors.size()) - 1);
  _needRebuild = false;
}
////////////////////// **# query #** /////////////////////
/**
 * Return up to __nearestNumber stored vectors nearest to __vector,
 * ordered nearest first.
 *
 * Lazily rebuilds the tree first. This mutates the object from a const
 * method via const_cast.
 */
template<class Vector, class Scalar>
inline typename VP_Tree<Vector, Scalar>::Vectors
VP_Tree<Vector, Scalar>::query(Vector const& __vector,
                               size_t __nearestNumber,
                               bool __compareWholeVector) const{
  const_cast<VP_Tree*>(this)->build();
  AnswerCompare cmp(&_vectors, __compareWholeVector);
  Answers heap(cmp);
  query(__vector, __nearestNumber, cmp, _root, &heap);
  // The heap pops its worst candidate first; a stack reverses that order
  // so the nearest answer ends up on top.
  std::stack<Answer> nearestOnTop;
  while(!heap.empty()){
    nearestOnTop.push(heap.top());
    heap.pop();
  }
  Vectors result;
  while(!nearestOnTop.empty()){
    result.push_back(_vectors[nearestOnTop.top()._index]);
    nearestOnTop.pop();
  }
  return result;
}
/////////////////// **# clear, reset #** /////////////////
// Drop every node and every stored vector; the dimension is kept.
template<class Vector, class Scalar>
void
VP_Tree<Vector, Scalar>::clear(){
  clear(_root);
  _root = NULL;
  _vectors.clear();
  _needRebuild = false;
}
/**
 * Clear the tree and set a new dimension (at least 1).
 *
 * @return the dimension actually stored.
 */
template<class Vector, class Scalar>
size_t
VP_Tree<Vector, Scalar>::reset(size_t __dimension){
  clear();
  _dimension = (__dimension < 1) ? 1 : __dimension;
  return _dimension;
}
/////////////////////// **# print #** ////////////////////
/**
 * Debug dump: element count, dimension, then the whole tree.
 *
 * The size_t values are converted to unsigned long explicitly because %lu
 * expects unsigned long, and size_t is not unsigned long on every platform
 * (e.g. 64-bit Windows). A mismatch there is undefined behaviour.
 */
template<class Vector, class Scalar>
void inline
VP_Tree<Vector, Scalar>::print() const{
  printf("\nsize = %lu, dimension = %lu\n",
         static_cast<unsigned long>(_vectors.size()),
         static_cast<unsigned long>(_dimension));
  print(_root, 1);
  printf("\n\n");
}
};