From ac6d2fcb7b1a77455895fa65b42502c9b29823fd Mon Sep 17 00:00:00 2001 From: cathook Date: Tue, 22 Apr 2014 21:01:49 +0800 Subject: add VP_Tree --- _test/meowpp_VP_Tree.cpp | 180 +++++++++++++++++++++++++ meowpp/dsa/VP_Tree.h | 164 +++++++++++++++++++++++ meowpp/dsa/VP_Tree.hpp | 336 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 680 insertions(+) create mode 100644 _test/meowpp_VP_Tree.cpp create mode 100644 meowpp/dsa/VP_Tree.h create mode 100644 meowpp/dsa/VP_Tree.hpp diff --git a/_test/meowpp_VP_Tree.cpp b/_test/meowpp_VP_Tree.cpp new file mode 100644 index 0000000..8d5e903 --- /dev/null +++ b/_test/meowpp_VP_Tree.cpp @@ -0,0 +1,180 @@ +#include "meowpp.h" + +#include + +#include +#include +#include +#include + +#include + +static int N = 100000; +static int D = 64; +static int MAX = 100; + +typedef long long lnt; + +struct MyVector{ + std::vector v; + int w; + // + MyVector(MyVector const& _v): v(_v.v), w(_v.w){ } + MyVector( ):v(D){ for(int i = 0; i < D; i++){ v[i] = (lnt)rand() % MAX; } } + MyVector(lnt k):v(D){ for(int i = 0; i < D; i++){ v[i] = k; } } + // + lnt & operator[](size_t n) { return v[n]; } + lnt const& operator[](size_t n) const{ return v[n]; } + bool operator<(MyVector const& v2) const{ return (w < v2.w); } + bool operator==(MyVector const& v2) const{ + for(int i = 0; i < D; i++) if(v[i] != v2[i]) return false; + return (w == v2.w); + } +}; + + +static lnt dist2(MyVector const& v1, MyVector const& v2){ + lnt k = 0; + for(int i = 0; i < D; i++){ + k += (v1[i] - v2[i]) * (v1[i] - v2[i]); + } + return k; +} + +static std::vector data; + +void show(MyVector const& v, std::vector const& r1, std::vector const& r2){ + if(N <= 20 && r1.size() <= 7){ + printf("\n"); + for(int i = 0; i < N; i++){ + printf("%3d) ", data[i].w); + for(int j = 0; j < D; j++) + printf("%8lld ", data[i][j]); + printf(" ===> %lld\n", dist2(data[i], v)); + } + printf("\n"); + printf("ask) "); + for(int j = 0; j < D; j++) + printf("%8lld ", v[j]); + printf("\n"); + printf("---------\n"); + for(int i = 0; i < r1.size(); i++){ + printf("%3d) ", r1[i].w); + for(int j = 0; j < D; j++) + printf("%8lld ", r1[i][j]); + printf(" ===> %lld\n", dist2(r1[i], v)); + } + printf("---------\n"); + for(int i = 0; i < r2.size(); i++){ + printf("%3d) ", r2[i].w); + for(int j = 0; j < D; j++) + printf("%8lld ", r2[i][j]); + printf(" ===> %lld\n", dist2(r2[i], v)); + } + } +} + +namespace VP{ + struct Answer{ + int i; + lnt d; + // + Answer(int _i, lnt _d): i(_i), d(_d){ } + Answer(Answer const& _a): i(_a.i), d(_a.d){ } + // + bool operator<(Answer const& b) const{ + if(d != b.d) return (d < b.d); + else return (data[i] < data[b.i]); + } + }; +} + +static std::vector find(MyVector const& v, int k){ + std::priority_queue qu; + for(int i = 0; i < std::min(k, N); i++){ + qu.push(VP::Answer(i, dist2(v, data[i]))); + } + for(int i = std::min(k, N); i < N; i++){ + qu.push(VP::Answer(i, dist2(v, data[i]))); + qu.pop(); + } + std::vector ret(qu.size()); + for(int i = (ssize_t)qu.size() - 1; i >= 0; i--){ + ret[i] = data[qu.top().i]; + qu.pop(); + } + return ret; +} + +TEST(VP_Tree){ + int t0, t1, t2; + + meow::VP_Tree tree(D); + + meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D); + data.resize(N); + for(int i = 0; i < N; i++){ + if(i <= N / 10) + data[i] = MyVector((lnt)i); + else{ + for(int j = 0; j < D; j++){ + data[i][j] = rand() % MAX; + } + } + } + for(int i = 0; i < N; i++){ + data[i].w = i; + } + for(int i = 0; i < N; i++){ + tree.insert(data[i]); + } + meow::messagePrintf(-1, "ok"); + meow::messagePrintf(1, "build"); + t0 = clock(); + tree.build(); + //tree.print(); + meow::messagePrintf(-1, "ok, %.3f seconds", (clock() - t0) * 1.0 / CLOCKS_PER_SEC); + + meow::messagePrintf(1, "query..."); + meow::KD_Tree::Vectors ret1, ret2; + for(int k = 1; k <= std::min(100, N); k++){ + meow::messagePrintf(1, "range k = %d", k); + t1 = t2 = 0; + for(int i = 0; i < 10; i++){ + MyVector ask; + + t0 = clock(); + tree.build(); + ret1 = tree.query(ask, k, true); + t1 += clock() - t0; + + t0 = clock(); + ret2 = find(ask, k); + t2 += clock() - t0; + + if(ret1.size() != ret2.size() && false){ + meow::messagePrintf(-1, "(%d)query fail, size error", i); + meow::messagePrintf(-1, "fail"); + return false; + } + for(int kk = 0, KK = ret1.size(); kk < KK; kk++){ + if(ret1[kk] == ret2[kk]){ + continue; + } + show(ask, ret1, ret2); + meow::messagePrintf(-1, "(%d)query fail", i); + meow::messagePrintf(-1, "fail"); + return false; + } + } + meow::messagePrintf(-1, "ok %.3f/%.3f", + t1 * 1.0 / CLOCKS_PER_SEC, + t2 * 1.0 / CLOCKS_PER_SEC + ); + } + meow::messagePrintf(-1, "ok"); + + + return true; +}; +; diff --git a/meowpp/dsa/VP_Tree.h b/meowpp/dsa/VP_Tree.h new file mode 100644 index 0000000..f8ab393 --- /dev/null +++ b/meowpp/dsa/VP_Tree.h @@ -0,0 +1,164 @@ +#ifndef VP_Tree_H__ +#define VP_Tree_H__ + +#include +#include +#include +#include +#include "../utility.h" + +namespace meow{ + //# + //#=== meow:: *VP_Tree* (C++ class) + //#==== Description + //# `VP_Tree` 用來維護由 *N個K維度向量所成的集合*, + //# 並可於該set中查找 *前i個離給定向量最接近的向量*. + + //# 不像 `KD_Tree` 二分樹每次都選擇一個維度去分, 分成小的跟大的, + //# `VP_Tree` 每次選一個點, 將資料分成 離這個點近的, 跟離這個點遠的. + //# 至於怎麼選呢...., 嘛還沒研究, 先random + //# + //#==== Template Class Operators Request + //#[options="header",width="70%",cols="1>m,1<,3 + class VP_Tree{ + public: + //#==== Custom Type Definitions + //# * `Vectors` <- `std::vector` + //# + typedef typename std::vector Vectors; + private: + // + struct Node{ + size_t _index; + Scalar _threshold; + Node* _nearChild; + Node* _farChild; + Node(size_t __index); + }; + struct Answer{ + size_t _index; + Scalar _dist2; + // + Answer(size_t __index, Scalar const& __dist2); + Answer(Answer const& __answer2); + }; + class AnswerCompare{ + private: + Vectors const* _vectors; + bool _cmpValue; + public: + AnswerCompare(Vectors const* __vectors, bool __cmpValue); + bool operator()(Answer const& __a, Answer const& __b) const; + }; + typedef std::vector AnswerV; + typedef std::priority_queue Answers; + // + Vectors _vectors; + Node* _root; + size_t _dimension; + bool _needRebuild; + // + Scalar distance2(Vector const& __v1, Vector const& __v2) const; + int distanceCompare(Scalar const& __a2, Scalar const& __b2, + Scalar const& __c2) const; + Scalar split(ssize_t __first, ssize_t __last, size_t __order, + Vector const& __center); + // + Node* build(ssize_t __first, ssize_t __last); + void query(Vector const& __vector, + size_t __k, + AnswerCompare const& __cmp, + Node const* __node, + Answers* __out) const; + void clear(Node* __root); + Node* dup(Node* __root); + // + void print(Node* __node, int depth = 1, + Node* __parent = NULL, bool __near = true) const; + public: + VP_Tree(); + VP_Tree(VP_Tree const& __tree2); + VP_Tree(size_t __dimension); + ~VP_Tree(); + VP_Tree& operator=(VP_Tree const& __tree2); + + //#==== Support Methods + //# + //# * N <- `this` 中擁有的資料數 + //# * D <- `this` 資料維度 + //# + //#[options="header",width="100%",cols="1>m,3>s,7<,3<,3^,20<",grid="rows"] + //#|===================================================================== + //#|Const?|Name | Parameters | Return_Type| Time_Complexity| Description + + + //#||insert|(Vector const& `v`)|void| O(1) + //#|將向量 `v` 加到set中 + void insert(Vector const& __vector); + + + //#||erase|(Vector const& `v`)|bool| O(N) + //#|將向量 `v` 從set中移除, '~TODO:可以再優化~' + bool erase (Vector const& __vector); + + + //#||build|()|void|O(KN logN) or O(1) + //#|檢查距上一次 `build()` 至今是否有 `insert/erase` 被呼叫, + //# 若有, 重新建樹, 否則不做事 + void build(); + + + //#||forceBuild|()|void|O(KN logN) + //#|重新建樹 + void forceBuild(); + + + //#|const|query|(Vector const& `v`,\size_t `i`,\bool `cmp`)|Vectors + //#|O(KN ^1-1/K^ ) + //#|於set中找尋距離 `v` 前 `i` 近的向量, 並依照由近而遠的順序排序. + //# 如果有兩個向量 `v1`,`v2` 距離一樣, 且 `cmp` 為 `true` , 則直接依照 + //# `v1 < v2` 來決定誰在前面. 最後回傳一陣列包含所有解. + Vectors query(Vector const& __vector, + size_t __nearestNumber, + bool __compareWholeVector) const; + + + //#||clear|()|void|O(1) + //#|清空所有資料 + void clear(); + + + //#||reset|(size_t `dimension`)|size_t|O(1) + //#|清空所有資料並且指定維度為 `max(1, dimension)` 並且回傳指定後的維度 + size_t reset(size_t __dimension); + + + //#|===================================================================== + + void print() const; + }; + //# + //#[NOTE] + //#======================================== + //#======================================== + //# + //# ''' +} + +#include "VP_Tree.hpp" + +#endif // VP_Tree_H__ diff --git a/meowpp/dsa/VP_Tree.hpp b/meowpp/dsa/VP_Tree.hpp new file mode 100644 index 0000000..a3a0d82 --- /dev/null +++ b/meowpp/dsa/VP_Tree.hpp @@ -0,0 +1,336 @@ + +#include +#include +#include "../utility.h" + +namespace meow{ + ///////////////////// **# Node #** /////////////////////// + template + inline + VP_Tree::Node::Node(size_t __index): + _index(__index), _nearChild(NULL), _farChild(NULL){ + } + ///////////////////// **# Answer #** ///////////////////// + template + inline + VP_Tree::Answer::Answer(size_t __index, + Scalar const& __dist2): + _index(__index), _dist2(__dist2){ + } + template + inline + VP_Tree::Answer::Answer(Answer const& __answer2): + _index(__answer2._index), _dist2(__answer2._dist2){ + } + template + inline + VP_Tree::AnswerCompare::AnswerCompare + (Vectors const* __vectors, bool __cmpValue): + _vectors(__vectors), _cmpValue(__cmpValue){ + } + template + inline bool + VP_Tree::AnswerCompare::operator()(Answer const& __a, + Answer const& __b) const{ + if(__a._dist2 < __b._dist2) return true; + if(__b._dist2 < __a._dist2) return false; + return (_cmpValue && ((*_vectors)[__a._index] < (*_vectors)[__b._index])); + } + //////// **# distance2, distanceCompare, split #** /////// + template + inline Scalar + VP_Tree::distance2(Vector const& __v1, + Vector const& __v2) const{ + Scalar ret(0); + for(size_t i = 0; i < _dimension; i++) ret += squ(__v1[i] - __v2[i]); + return ret; + } + template + inline int + VP_Tree::distanceCompare(Scalar const& __a2, + Scalar const& __b2, + Scalar const& __c2) const{ + // test if sqrt(__a2) +- sqrt(|__b2|) <= sqrt(__c2) + //printf("abc = %lld %lld %lld\n", __a2, __b2, __c2); + if(__b2 < 0){ + return -distanceCompare(__c2, -__b2, __a2); + } + Scalar cab(__c2 - __a2 - __b2); + if(cab < Scalar(0)) return 1; + Scalar ab2(Scalar(4) * __a2 * __b2), cab2(squ(cab)); + if ( ab2 < cab2) return -1; + else if(cab2 < ab2) return 1; + else return 0; + } + template + inline Scalar + VP_Tree::split(ssize_t __first, ssize_t __last, + size_t __order, Vector const& __center){ + //printf("%ld %ld %lu\n", __first, __last, __order); + ssize_t first0 = __first; + ssize_t last0 = __last; + ssize_t order0 = __order; + std::vector dist2(__last - __first + 1); + for(ssize_t i = __first; i <= __last; i++){ + dist2[i - first0] = distance2(_vectors[i], __center); + } + while(__first < __last){ + size_t threshold_index = __first + rand() % (__last - __first + 1); + Scalar threshold(dist2[threshold_index - first0]); + /* + printf("range(%ld, %ld) dist2 = %3lld from %d\n", + __first - first0, __last - first0, + threshold, threshold_index - first0); + for(int i = first0; i <= last0; i++){ + if(i == __first) printf("+"); + if(i == threshold_index) printf("<"); + printf("<%lld,%lld,(%lld)>", _vectors[i][0], _vectors[i][1], + dist2[i - first0]); + if(i == threshold_index) printf(">"); + if(i == __last) printf("+"); + printf(" "); + } + printf("\n"); + // */ + size_t large_first = __last + 1; + for(size_t i = __first; __first <= large_first - 1; large_first--){ + if(threshold < dist2[large_first - 1 - first0]) continue; + while(i < large_first - 1 && !(threshold < dist2[i - first0])) i++; + if(i < large_first - 1){ + std::swap(dist2 [large_first - 1 - first0], dist2 [i - first0]); + std::swap(_vectors[large_first - 1 ], _vectors[i ]); + i++; + }else{ + break; + } + } + if(large_first == __last + 1){ + std::swap(dist2 [threshold_index-first0], dist2 [__last-first0]); + std::swap(_vectors[threshold_index ], _vectors[__last ]); + if(__order == __last - __first){ + __first = __last; + break; + } + __last--; + }else{ + if(__order < large_first - __first){ + __last = large_first - 1; + }else{ + __order -= large_first - __first; + __first = large_first; + } + } + } + /* + for(int i = first0; i <= last0; i++){ + if(i == __first) printf("+"); + if(i - first0 == order0) printf("<"); + printf("<%lld,%lld,(%lld)>", _vectors[i][0], _vectors[i][1], + dist2[i - first0]); + if(i - first0 == order0) printf(">"); + if(i == __first) printf("+"); + printf(" "); + } + printf("\n"); + printf("dist2(from<%lld,%lld>) = %lld\n", + __center[0], __center[1], + dist2[__first - first0]); + // */ + return dist2[__first - first0]; + } + ////////////////////// **# build() #** /////////////////// + template + inline typename VP_Tree::Node* + VP_Tree::build(ssize_t __first, ssize_t __last){ + if(__first > __last) return NULL; + Node* ret = new Node(__first); + if(__first < __last){ + std::swap(_vectors[__first], + _vectors[__first + rand() % (__last - __first + 1)]); + ssize_t mid = (__first + 1 + __last + 1) / 2; + ret->_threshold = split(__first + 1, __last, mid - (__first + 1), + _vectors[__first]); + ret->_nearChild = build(__first + 1, mid - 1 ); + ret->_farChild = build( mid , __last); + } + return ret; + } + ////////////////////// **# query() #** /////////////////// + template + inline void + VP_Tree::query(Vector const& __vector, + size_t __k, + AnswerCompare const& __cmp, + Node const* __node, + Answers* __out) const{ + if(__node == NULL) return ; + Scalar dist2 = distance2(__vector, _vectors[__node->_index]); + Answer my_ans(__node->_index, dist2); + if(__out->size() < __k || __cmp(my_ans, __out->top())){ + __out->push(my_ans); + if(__out->size() > __k){ + __out->pop(); + } + } + if(__node->_nearChild == NULL && __node->_farChild == NULL) return ; + if(__out->size() < __k || distanceCompare(dist2, -__out->top()._dist2, + __node->_threshold) <= 0){ + query(__vector, __k, __cmp, __node->_nearChild, __out); + } + if(__out->size() < __k || distanceCompare(dist2, __out->top()._dist2, + __node->_threshold) >= 0){ + query(__vector, __k, __cmp, __node->_farChild, __out); + } + } + ///////////////// **# clear(), dup() #** ///////////////// + template + inline void + VP_Tree::clear(Node* __root){ + if(__root == NULL) return ; + clear(__root->_nearChild); + clear(__root->_farChild); + delete __root; + } + template + inline typename VP_Tree::Node* + VP_Tree::dup(Node* __root){ + if(__root == NULL) return ; + Node* ret = new Node(__root->_index); + ret->_threshold = __root->_threshold; + ret->_nearChild = dup(__root->_nearChild); + ret->_farChild = dup(__root->_farChild ); + return ret; + } + /////////////////////// **# print #** //////////////////// + template + inline void + VP_Tree::print(Node* __node, int depth, + Node* __parent, bool __near) const{ + if(__node == NULL) return ; + printf("%*s%c)Me<%lld,%lld>, rad2 = %7lld", + depth * 2, "", + __near ? 'N' : 'F', + _vectors[__node->_index][0], _vectors[__node->_index][1], + __node->_threshold); + if(__parent != NULL){ + printf(" ---<%lld,%lld>:: %lld\n", + _vectors[__parent->_index][0], + _vectors[__parent->_index][1], + distance2(_vectors[__parent->_index], + _vectors[__node ->_index])); + }else{ + printf("\n"); + } + print(__node->_nearChild, depth + 1, __node, true ); + print(__node->_farChild , depth + 1, __node, false); + } + ///////// **# construre/destructure/copy oper #** //////// + template + inline + VP_Tree::VP_Tree(): + _root(NULL), _vectors(0), _dimension(0), _needRebuild(false){ + reset(0); + } + template + inline + VP_Tree::VP_Tree(VP_Tree const& __tree2): + _root(dup(__tree2._root)), + _vectors(__tree2._vectors), + _dimension(__tree2._dimension), + _needRebuild(__tree2._needRebuild){ + } + template + inline + VP_Tree::VP_Tree(size_t __dimension): + _root(NULL), _vectors(0), _dimension(0), _needRebuild(false){ + reset(__dimension); + } + template + inline + VP_Tree::~VP_Tree(){ + clear(_root); + } + template + inline VP_Tree& + VP_Tree::operator=(VP_Tree const& __tree2){ + reset(__tree2._dimension); + _vectors = __tree2._vectors; + _root = dup(__tree2._root); + _needRebuild = __tree2._needRebuild; + } + ////////////////// **# insert, erase #** ///////////////// + template + inline void + VP_Tree::insert(Vector const& __vector){ + _vectors.push_back(__vector); + _needRebuild = true; + } + template + inline bool + VP_Tree::erase(Vector const& __vector){ + for(ssize_t i = 0, I = _vectors.size(); i < I; i++){ + if(_vectors[i] == __vector){ + if(i != I - 1) std::swap(_vectors[i], _vectors[I - 1]); + _needRebuild = true; + _vectors.pop_back(); + return true; + } + } + return false; + } + ////////////////// **# build, forceBuild #** ///////////// + template + inline void + VP_Tree::build(){ + if(_needRebuild){ + forceBuild(); + } + } + template + inline void + VP_Tree::forceBuild(){ + _root = build(0, (size_t)_vectors.size() - 1); + _needRebuild = false; + } + ////////////////////// **# query #** ///////////////////// + template + inline typename VP_Tree::Vectors + VP_Tree::query(Vector const& __vector, + size_t __nearestNumber, + bool __compareWholeVector) const{ + ((VP_Tree*)this)->build(); + AnswerCompare cmp(&_vectors, __compareWholeVector); + Answers answers(cmp); + query(__vector, __nearestNumber, cmp, _root, &answers); + std::stack rev; + for( ; !answers.empty(); answers.pop()) rev.push(answers.top()); + Vectors ret; + for( ; !rev.empty(); rev.pop()) ret.push_back(_vectors[rev.top()._index]); + return ret; + } + /////////////////// **# clear, reset #** ///////////////// + template + void + VP_Tree::clear(){ + clear(_root); + _vectors.clear(); + _root = NULL; + _needRebuild = false; + } + template + size_t + VP_Tree::reset(size_t __dimension){ + clear(); + _dimension = std::max((size_t)1, __dimension); + return _dimension; + } + + /////////////////////// **# print #** //////////////////// + template + void inline + VP_Tree::print() const{ + printf("\nsize = %lu, dimension = %lu\n", _vectors.size(), _dimension); + print(_root, 1); + printf("\n\n"); + } +}; -- cgit v1.2.3