diff options
Diffstat (limited to '_test')
-rw-r--r-- | _test/meowpp_VP_Tree.cpp | 180 |
1 files changed, 180 insertions, 0 deletions
diff --git a/_test/meowpp_VP_Tree.cpp b/_test/meowpp_VP_Tree.cpp new file mode 100644 index 0000000..8d5e903 --- /dev/null +++ b/_test/meowpp_VP_Tree.cpp @@ -0,0 +1,180 @@ +#include "meowpp.h" + +#include <vector> + +#include <cmath> +#include <cstdlib> +#include <algorithm> +#include <ctime> + +#include <queue> + +static int N = 100000; +static int D = 64; +static int MAX = 100; + +typedef long long lnt; + +struct MyVector{ + std::vector<lnt> v; + int w; + // + MyVector(MyVector const& _v): v(_v.v), w(_v.w){ } + MyVector( ):v(D){ for(int i = 0; i < D; i++){ v[i] = (lnt)rand() % MAX; } } + MyVector(lnt k):v(D){ for(int i = 0; i < D; i++){ v[i] = k; } } + // + lnt & operator[](size_t n) { return v[n]; } + lnt const& operator[](size_t n) const{ return v[n]; } + bool operator<(MyVector const& v2) const{ return (w < v2.w); } + bool operator==(MyVector const& v2) const{ + for(int i = 0; i < D; i++) if(v[i] != v2[i]) return false; + return (w == v2.w); + } +}; + + +static lnt dist2(MyVector const& v1, MyVector const& v2){ + lnt k = 0; + for(int i = 0; i < D; i++){ + k += (v1[i] - v2[i]) * (v1[i] - v2[i]); + } + return k; +} + +static std::vector<MyVector> data; + +void show(MyVector const& v, std::vector<MyVector> const& r1, std::vector<MyVector> const& r2){ + if(N <= 20 && r1.size() <= 7){ + printf("\n"); + for(int i = 0; i < N; i++){ + printf("%3d) ", data[i].w); + for(int j = 0; j < D; j++) + printf("%8lld ", data[i][j]); + printf(" ===> %lld\n", dist2(data[i], v)); + } + printf("\n"); + printf("ask) "); + for(int j = 0; j < D; j++) + printf("%8lld ", v[j]); + printf("\n"); + printf("---------\n"); + for(int i = 0; i < r1.size(); i++){ + printf("%3d) ", r1[i].w); + for(int j = 0; j < D; j++) + printf("%8lld ", r1[i][j]); + printf(" ===> %lld\n", dist2(r1[i], v)); + } + printf("---------\n"); + for(int i = 0; i < r2.size(); i++){ + printf("%3d) ", r2[i].w); + for(int j = 0; j < D; j++) + printf("%8lld ", r2[i][j]); + printf(" ===> %lld\n", dist2(r2[i], v)); + } + } +} + +namespace VP{ + struct Answer{ + int i; + lnt d; + // + Answer(int _i, lnt _d): i(_i), d(_d){ } + Answer(Answer const& _a): i(_a.i), d(_a.d){ } + // + bool operator<(Answer const& b) const{ + if(d != b.d) return (d < b.d); + else return (data[i] < data[b.i]); + } + }; +} + +static std::vector<MyVector> find(MyVector const& v, int k){ + std::priority_queue<VP::Answer> qu; + for(int i = 0; i < std::min(k, N); i++){ + qu.push(VP::Answer(i, dist2(v, data[i]))); + } + for(int i = std::min(k, N); i < N; i++){ + qu.push(VP::Answer(i, dist2(v, data[i]))); + qu.pop(); + } + std::vector<MyVector> ret(qu.size()); + for(int i = (ssize_t)qu.size() - 1; i >= 0; i--){ + ret[i] = data[qu.top().i]; + qu.pop(); + } + return ret; +} + +TEST(VP_Tree){ + int t0, t1, t2; + + meow::VP_Tree<MyVector, lnt> tree(D); + + meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D); + data.resize(N); + for(int i = 0; i < N; i++){ + if(i <= N / 10) + data[i] = MyVector((lnt)i); + else{ + for(int j = 0; j < D; j++){ + data[i][j] = rand() % MAX; + } + } + } + for(int i = 0; i < N; i++){ + data[i].w = i; + } + for(int i = 0; i < N; i++){ + tree.insert(data[i]); + } + meow::messagePrintf(-1, "ok"); + meow::messagePrintf(1, "build"); + t0 = clock(); + tree.build(); + //tree.print(); + meow::messagePrintf(-1, "ok, %.3f seconds", (clock() - t0) * 1.0 / CLOCKS_PER_SEC); + + meow::messagePrintf(1, "query..."); + meow::KD_Tree<MyVector, lnt>::Vectors ret1, ret2; + for(int k = 1; k <= std::min(100, N); k++){ + meow::messagePrintf(1, "range k = %d", k); + t1 = t2 = 0; + for(int i = 0; i < 10; i++){ + MyVector ask; + + t0 = clock(); + tree.build(); + ret1 = tree.query(ask, k, true); + t1 += clock() - t0; + + t0 = clock(); + ret2 = find(ask, k); + t2 += clock() - t0; + + if(ret1.size() != ret2.size() && false){ + meow::messagePrintf(-1, "(%d)query fail, size error", i); + meow::messagePrintf(-1, "fail"); + return false; + } + for(int kk = 0, KK = ret1.size(); kk < KK; kk++){ + if(ret1[kk] == ret2[kk]){ + continue; + } + show(ask, ret1, ret2); + meow::messagePrintf(-1, "(%d)query fail", i); + meow::messagePrintf(-1, "fail"); + return false; + } + } + meow::messagePrintf(-1, "ok %.3f/%.3f", + t1 * 1.0 / CLOCKS_PER_SEC, + t2 * 1.0 / CLOCKS_PER_SEC + ); + } + meow::messagePrintf(-1, "ok"); + + + return true; +}; +; |