diff options
author | cathook <b01902109@csie.ntu.edu.tw> | 2014-04-22 21:01:49 +0800 |
---|---|---|
committer | cathook <b01902109@csie.ntu.edu.tw> | 2014-04-22 21:01:49 +0800 |
commit | ac6d2fcb7b1a77455895fa65b42502c9b29823fd (patch) | |
tree | 1ad8e27c466a231ca076c102bb58f15eb58554fd /meowpp | |
parent | a74ca680621bd2c429dc231b34689d47d883c9c2 (diff) | |
download | meow-ac6d2fcb7b1a77455895fa65b42502c9b29823fd.tar meow-ac6d2fcb7b1a77455895fa65b42502c9b29823fd.tar.gz meow-ac6d2fcb7b1a77455895fa65b42502c9b29823fd.tar.bz2 meow-ac6d2fcb7b1a77455895fa65b42502c9b29823fd.tar.lz meow-ac6d2fcb7b1a77455895fa65b42502c9b29823fd.tar.xz meow-ac6d2fcb7b1a77455895fa65b42502c9b29823fd.tar.zst meow-ac6d2fcb7b1a77455895fa65b42502c9b29823fd.zip |
add VP_Tree
Diffstat (limited to 'meowpp')
-rw-r--r-- | meowpp/dsa/VP_Tree.h | 164 | ||||
-rw-r--r-- | meowpp/dsa/VP_Tree.hpp | 336 |
2 files changed, 500 insertions, 0 deletions
diff --git a/meowpp/dsa/VP_Tree.h b/meowpp/dsa/VP_Tree.h new file mode 100644 index 0000000..f8ab393 --- /dev/null +++ b/meowpp/dsa/VP_Tree.h @@ -0,0 +1,164 @@ +#ifndef VP_Tree_H__ +#define VP_Tree_H__ + +#include <list> +#include <vector> +#include <cstdlib> +#include <queue> +#include "../utility.h" + +namespace meow{ + //# + //#=== meow:: *VP_Tree<Vector, Scalar>* (C++ class) + //#==== Description + //# `VP_Tree` 用來維護由 *N個K維度向量所成的集合*, + //# 並可於該set中查找 *前i個離給定向量最接近的向量*. + + //# 不像 `KD_Tree` 二分樹每次都選擇一個維度去分, 分成小的跟大的, + //# `VP_Tree` 每次選一個點, 將資料分成 離這個點近的, 跟離這個點遠的. + //# 至於怎麼選呢...., 嘛還沒研究, 先random + //# + //#==== Template Class Operators Request + //#[options="header",width="70%",cols="1>m,1<,3<s,5<,3<,15<",grid="rows"] + //#|===================================================================== + //#|Const?|Typename| Operator | Parameters | Return_Type| Description + //#|const | Vector|operator[] |(size_t `n`) | Scalar | 取得第 `n` 維度量 + //#|const | Vector|operator= |(Vector `v`) | Vector& | copy operator + //#|const | Vector|operator< |(Vector `v`) | bool | 權重比較 + //#|const | Scalar| 'Scalar' |(int `n`) | Scalar | 建構子, + //# 其中一定`n=0 or 4` + //#|const | Scalar|operator* |(Scalar `s`) | Scalar | 相乘 + //#|const | Scalar|operator+ |(Scalar `s`) | Scalar | 相加 + //#|const | Scalar|operator- |(Scalar `s`) | Scalar | 相差 + //#|const | Scalar|operator- |( ) | Scalar | 取負號 + //#|const | Scalar|operator< |(Scalar `s`) | bool | 大小比較 + //#|===================================================================== + //# + template<class Vector, class Scalar> + class VP_Tree{ + public: + //#==== Custom Type Definitions + //# * `Vectors` <- `std::vector<Vector>` + //# + typedef typename std::vector<Vector> Vectors; + private: + // + struct Node{ + size_t _index; + Scalar _threshold; + Node* _nearChild; + Node* _farChild; + Node(size_t __index); + }; + struct Answer{ + size_t _index; + Scalar _dist2; + // + Answer(size_t __index, Scalar const& __dist2); + Answer(Answer const& __answer2); + }; + class AnswerCompare{ + private: + Vectors const* _vectors; + bool _cmpValue; + public: + AnswerCompare(Vectors const* __vectors, bool __cmpValue); + bool operator()(Answer const& __a, Answer const& __b) const; + }; + typedef std::vector<Answer> AnswerV; + typedef std::priority_queue<Answer, AnswerV, AnswerCompare> Answers; + // + Vectors _vectors; + Node* _root; + size_t _dimension; + bool _needRebuild; + // + Scalar distance2(Vector const& __v1, Vector const& __v2) const; + int distanceCompare(Scalar const& __a2, Scalar const& __b2, + Scalar const& __c2) const; + Scalar split(ssize_t __first, ssize_t __last, size_t __order, + Vector const& __center); + // + Node* build(ssize_t __first, ssize_t __last); + void query(Vector const& __vector, + size_t __k, + AnswerCompare const& __cmp, + Node const* __node, + Answers* __out) const; + void clear(Node* __root); + Node* dup(Node* __root); + // + void print(Node* __node, int depth = 1, + Node* __parent = NULL, bool __near = true) const; + public: + VP_Tree(); + VP_Tree(VP_Tree const& __tree2); + VP_Tree(size_t __dimension); + ~VP_Tree(); + VP_Tree& operator=(VP_Tree const& __tree2); + + //#==== Support Methods + //# + //# * N <- `this` 中擁有的資料數 + //# * D <- `this` 資料維度 + //# + //#[options="header",width="100%",cols="1>m,3>s,7<,3<,3^,20<",grid="rows"] + //#|===================================================================== + //#|Const?|Name | Parameters | Return_Type| Time_Complexity| Description + + + //#||insert|(Vector const& `v`)|void| O(1) + //#|將向量 `v` 加到set中 + void insert(Vector const& __vector); + + + //#||erase|(Vector const& `v`)|bool| O(N) + //#|將向量 `v` 從set中移除, '~TODO:可以再優化~' + bool erase (Vector const& __vector); + + + //#||build|()|void|O(KN logN) or O(1) + //#|檢查距上一次 `build()` 至今是否有 `insert/erase` 被呼叫, + //# 若有, 重新建樹, 否則不做事 + void build(); + + + //#||forceBuild|()|void|O(KN logN) + //#|重新建樹 + void forceBuild(); + + + //#|const|query|(Vector const& `v`,\size_t `i`,\bool `cmp`)|Vectors + //#|O(KN ^1-1/K^ ) + //#|於set中找尋距離 `v` 前 `i` 近的向量, 並依照由近而遠的順序排序. + //# 如果有兩個向量 `v1`,`v2` 距離一樣, 且 `cmp` 為 `true` , 則直接依照 + //# `v1 < v2` 來決定誰在前面. 最後回傳一陣列包含所有解. + Vectors query(Vector const& __vector, + size_t __nearestNumber, + bool __compareWholeVector) const; + + + //#||clear|()|void|O(1) + //#|清空所有資料 + void clear(); + + + //#||reset|(size_t `dimension`)|size_t|O(1) + //#|清空所有資料並且指定維度為 `max(1, dimension)` 並且回傳指定後的維度 + size_t reset(size_t __dimension); + + + //#|===================================================================== + + void print() const; + }; + //# + //#[NOTE] + //#======================================== + //#======================================== + //# + //# ''' +} + +#include "VP_Tree.hpp" + +#endif // VP_Tree_H__ diff --git a/meowpp/dsa/VP_Tree.hpp b/meowpp/dsa/VP_Tree.hpp new file mode 100644 index 0000000..a3a0d82 --- /dev/null +++ b/meowpp/dsa/VP_Tree.hpp @@ -0,0 +1,336 @@ + +#include <cstdlib> +#include <algorithm> +#include "../utility.h" + +namespace meow{ + ///////////////////// **# Node #** /////////////////////// + template<class Vector, class Scalar> + inline + VP_Tree<Vector, Scalar>::Node::Node(size_t __index): + _index(__index), _nearChild(NULL), _farChild(NULL){ + } + ///////////////////// **# Answer #** ///////////////////// + template<class Vector, class Scalar> + inline + VP_Tree<Vector, Scalar>::Answer::Answer(size_t __index, + Scalar const& __dist2): + _index(__index), _dist2(__dist2){ + } + template<class Vector, class Scalar> + inline + VP_Tree<Vector, Scalar>::Answer::Answer(Answer const& __answer2): + _index(__answer2._index), _dist2(__answer2._dist2){ + } + template<class Vector, class Scalar> + inline + VP_Tree<Vector, Scalar>::AnswerCompare::AnswerCompare + (Vectors const* __vectors, bool __cmpValue): + _vectors(__vectors), _cmpValue(__cmpValue){ + } + template<class Vector, class Scalar> + inline bool + VP_Tree<Vector, Scalar>::AnswerCompare::operator()(Answer const& __a, + Answer const& __b) const{ + if(__a._dist2 < __b._dist2) return true; + if(__b._dist2 < __a._dist2) return false; + return (_cmpValue && ((*_vectors)[__a._index] < (*_vectors)[__b._index])); + } + //////// **# distance2, distanceCompare, split #** /////// + template<class Vector, class Scalar> + inline Scalar + VP_Tree<Vector, Scalar>::distance2(Vector const& __v1, + Vector const& __v2) const{ + Scalar ret(0); + for(size_t i = 0; i < _dimension; i++) ret += squ(__v1[i] - __v2[i]); + return ret; + } + template<class Vector, class Scalar> + inline int + VP_Tree<Vector, Scalar>::distanceCompare(Scalar const& __a2, + Scalar const& __b2, + Scalar const& __c2) const{ + // test if sqrt(__a2) +- sqrt(|__b2|) <= sqrt(__c2) + //printf("abc = %lld %lld %lld\n", __a2, __b2, __c2); + if(__b2 < 0){ + return -distanceCompare(__c2, -__b2, __a2); + } + Scalar cab(__c2 - __a2 - __b2); + if(cab < Scalar(0)) return 1; + Scalar ab2(Scalar(4) * __a2 * __b2), cab2(squ(cab)); + if ( ab2 < cab2) return -1; + else if(cab2 < ab2) return 1; + else return 0; + } + template<class Vector, class Scalar> + inline Scalar + VP_Tree<Vector, Scalar>::split(ssize_t __first, ssize_t __last, + size_t __order, Vector const& __center){ + //printf("%ld %ld %lu\n", __first, __last, __order); + ssize_t first0 = __first; + ssize_t last0 = __last; + ssize_t order0 = __order; + std::vector<Scalar> dist2(__last - __first + 1); + for(ssize_t i = __first; i <= __last; i++){ + dist2[i - first0] = distance2(_vectors[i], __center); + } + while(__first < __last){ + size_t threshold_index = __first + rand() % (__last - __first + 1); + Scalar threshold(dist2[threshold_index - first0]); + /* + printf("range(%ld, %ld) dist2 = %3lld from %d\n", + __first - first0, __last - first0, + threshold, threshold_index - first0); + for(int i = first0; i <= last0; i++){ + if(i == __first) printf("+"); + if(i == threshold_index) printf("<"); + printf("<%lld,%lld,(%lld)>", _vectors[i][0], _vectors[i][1], + dist2[i - first0]); + if(i == threshold_index) printf(">"); + if(i == __last) printf("+"); + printf(" "); + } + printf("\n"); + // */ + size_t large_first = __last + 1; + for(size_t i = __first; __first <= large_first - 1; large_first--){ + if(threshold < dist2[large_first - 1 - first0]) continue; + while(i < large_first - 1 && !(threshold < dist2[i - first0])) i++; + if(i < large_first - 1){ + std::swap(dist2 [large_first - 1 - first0], dist2 [i - first0]); + std::swap(_vectors[large_first - 1 ], _vectors[i ]); + i++; + }else{ + break; + } + } + if(large_first == __last + 1){ + std::swap(dist2 [threshold_index-first0], dist2 [__last-first0]); + std::swap(_vectors[threshold_index ], _vectors[__last ]); + if(__order == __last - __first){ + __first = __last; + break; + } + __last--; + }else{ + if(__order < large_first - __first){ + __last = large_first - 1; + }else{ + __order -= large_first - __first; + __first = large_first; + } + } + } + /* + for(int i = first0; i <= last0; i++){ + if(i == __first) printf("+"); + if(i - first0 == order0) printf("<"); + printf("<%lld,%lld,(%lld)>", _vectors[i][0], _vectors[i][1], + dist2[i - first0]); + if(i - first0 == order0) printf(">"); + if(i == __first) printf("+"); + printf(" "); + } + printf("\n"); + printf("dist2(from<%lld,%lld>) = %lld\n", + __center[0], __center[1], + dist2[__first - first0]); + // */ + return dist2[__first - first0]; + } + ////////////////////// **# build() #** /////////////////// + template<class Vector, class Scalar> + inline typename VP_Tree<Vector, Scalar>::Node* + VP_Tree<Vector, Scalar>::build(ssize_t __first, ssize_t __last){ + if(__first > __last) return NULL; + Node* ret = new Node(__first); + if(__first < __last){ + std::swap(_vectors[__first], + _vectors[__first + rand() % (__last - __first + 1)]); + ssize_t mid = (__first + 1 + __last + 1) / 2; + ret->_threshold = split(__first + 1, __last, mid - (__first + 1), + _vectors[__first]); + ret->_nearChild = build(__first + 1, mid - 1 ); + ret->_farChild = build( mid , __last); + } + return ret; + } + ////////////////////// **# query() #** /////////////////// + template<class Vector, class Scalar> + inline void + VP_Tree<Vector, Scalar>::query(Vector const& __vector, + size_t __k, + AnswerCompare const& __cmp, + Node const* __node, + Answers* __out) const{ + if(__node == NULL) return ; + Scalar dist2 = distance2(__vector, _vectors[__node->_index]); + Answer my_ans(__node->_index, dist2); + if(__out->size() < __k || __cmp(my_ans, __out->top())){ + __out->push(my_ans); + if(__out->size() > __k){ + __out->pop(); + } + } + if(__node->_nearChild == NULL && __node->_farChild == NULL) return ; + if(__out->size() < __k || distanceCompare(dist2, -__out->top()._dist2, + __node->_threshold) <= 0){ + query(__vector, __k, __cmp, __node->_nearChild, __out); + } + if(__out->size() < __k || distanceCompare(dist2, __out->top()._dist2, + __node->_threshold) >= 0){ + query(__vector, __k, __cmp, __node->_farChild, __out); + } + } + ///////////////// **# clear(), dup() #** ///////////////// + template<class Vector, class Scalar> + inline void + VP_Tree<Vector, Scalar>::clear(Node* __root){ + if(__root == NULL) return ; + clear(__root->_nearChild); + clear(__root->_farChild); + delete __root; + } + template<class Vector, class Scalar> + inline typename VP_Tree<Vector, Scalar>::Node* + VP_Tree<Vector, Scalar>::dup(Node* __root){ + if(__root == NULL) return ; + Node* ret = new Node(__root->_index); + ret->_threshold = __root->_threshold; + ret->_nearChild = dup(__root->_nearChild); + ret->_farChild = dup(__root->_farChild ); + return ret; + } + /////////////////////// **# print #** //////////////////// + template<class Vector, class Scalar> + inline void + VP_Tree<Vector, Scalar>::print(Node* __node, int depth, + Node* __parent, bool __near) const{ + if(__node == NULL) return ; + printf("%*s%c)Me<%lld,%lld>, rad2 = %7lld", + depth * 2, "", + __near ? 'N' : 'F', + _vectors[__node->_index][0], _vectors[__node->_index][1], + __node->_threshold); + if(__parent != NULL){ + printf(" ---<%lld,%lld>:: %lld\n", + _vectors[__parent->_index][0], + _vectors[__parent->_index][1], + distance2(_vectors[__parent->_index], + _vectors[__node ->_index])); + }else{ + printf("\n"); + } + print(__node->_nearChild, depth + 1, __node, true ); + print(__node->_farChild , depth + 1, __node, false); + } + ///////// **# construre/destructure/copy oper #** //////// + template<class Vector, class Scalar> + inline + VP_Tree<Vector, Scalar>::VP_Tree(): + _root(NULL), _vectors(0), _dimension(0), _needRebuild(false){ + reset(0); + } + template<class Vector, class Scalar> + inline + VP_Tree<Vector, Scalar>::VP_Tree(VP_Tree<Vector, Scalar> const& __tree2): + _root(dup(__tree2._root)), + _vectors(__tree2._vectors), + _dimension(__tree2._dimension), + _needRebuild(__tree2._needRebuild){ + } + template<class Vector, class Scalar> + inline + VP_Tree<Vector, Scalar>::VP_Tree(size_t __dimension): + _root(NULL), _vectors(0), _dimension(0), _needRebuild(false){ + reset(__dimension); + } + template<class Vector, class Scalar> + inline + VP_Tree<Vector, Scalar>::~VP_Tree(){ + clear(_root); + } + template<class Vector, class Scalar> + inline VP_Tree<Vector, Scalar>& + VP_Tree<Vector, Scalar>::operator=(VP_Tree const& __tree2){ + reset(__tree2._dimension); + _vectors = __tree2._vectors; + _root = dup(__tree2._root); + _needRebuild = __tree2._needRebuild; + } + ////////////////// **# insert, erase #** ///////////////// + template<class Vector, class Scalar> + inline void + VP_Tree<Vector, Scalar>::insert(Vector const& __vector){ + _vectors.push_back(__vector); + _needRebuild = true; + } + template<class Vector, class Scalar> + inline bool + VP_Tree<Vector, Scalar>::erase(Vector const& __vector){ + for(ssize_t i = 0, I = _vectors.size(); i < I; i++){ + if(_vectors[i] == __vector){ + if(i != I - 1) std::swap(_vectors[i], _vectors[I - 1]); + _needRebuild = true; + _vectors.pop_back(); + return true; + } + } + return false; + } + ////////////////// **# build, forceBuild #** ///////////// + template<class Vector, class Scalar> + inline void + VP_Tree<Vector, Scalar>::build(){ + if(_needRebuild){ + forceBuild(); + } + } + template<class Vector, class Scalar> + inline void + VP_Tree<Vector, Scalar>::forceBuild(){ + _root = build(0, (size_t)_vectors.size() - 1); + _needRebuild = false; + } + ////////////////////// **# query #** ///////////////////// + template<class Vector, class Scalar> + inline typename VP_Tree<Vector, Scalar>::Vectors + VP_Tree<Vector, Scalar>::query(Vector const& __vector, + size_t __nearestNumber, + bool __compareWholeVector) const{ + ((VP_Tree*)this)->build(); + AnswerCompare cmp(&_vectors, __compareWholeVector); + Answers answers(cmp); + query(__vector, __nearestNumber, cmp, _root, &answers); + std::stack<Answer> rev; + for( ; !answers.empty(); answers.pop()) rev.push(answers.top()); + Vectors ret; + for( ; !rev.empty(); rev.pop()) ret.push_back(_vectors[rev.top()._index]); + return ret; + } + /////////////////// **# clear, reset #** ///////////////// + template<class Vector, class Scalar> + void + VP_Tree<Vector, Scalar>::clear(){ + clear(_root); + _vectors.clear(); + _root = NULL; + _needRebuild = false; + } + template<class Vector, class Scalar> + size_t + VP_Tree<Vector, Scalar>::reset(size_t __dimension){ + clear(); + _dimension = std::max((size_t)1, __dimension); + return _dimension; + } + + /////////////////////// **# print #** //////////////////// + template<class Vector, class Scalar> + void inline + VP_Tree<Vector, Scalar>::print() const{ + printf("\nsize = %lu, dimension = %lu\n", _vectors.size(), _dimension); + print(_root, 1); + printf("\n\n"); + } +}; |