diff options
Diffstat (limited to '_test/meowpp_KD_Tree.cpp')
-rw-r--r-- | _test/meowpp_KD_Tree.cpp | 186 |
1 files changed, 186 insertions, 0 deletions
diff --git a/_test/meowpp_KD_Tree.cpp b/_test/meowpp_KD_Tree.cpp new file mode 100644 index 0000000..dcbda5f --- /dev/null +++ b/_test/meowpp_KD_Tree.cpp @@ -0,0 +1,186 @@ +#include "meowpp.h" + +#include <vector> + +#include <cmath> +#include <cstdlib> +#include <algorithm> +#include <ctime> +#include <queue> + +static int N = 10000; +static int D = 5; + +static double dist2(std::vector<double> const& v1, std::vector<double> const& v2){ + double ret = 0; + for(int i = 0; i < D; i++){ + ret += meow::squ(v1[i] - v2[i]); + } + return ret; +} + +static std::vector< std::vector<double> > data; +static std::vector< double > dist; +static std::vector< int > order; + + +struct Answer{ + double dist; + int id; + Answer(double _dist, int _id): dist(_dist), id(_id){ } + bool operator<(Answer const& b) const{ + if(dist != b.dist) return (dist < b.dist); + return (id < b.id); + } +}; + + +static void find(std::vector<double> const& v, int k){ + std::priority_queue<Answer> qu; + for(int i = 0; i < k; i++){ + qu.push(Answer(dist2(v, data[i]), i)); + } + for(int i = k; i < N; i++){ + qu.push(Answer(dist2(v, data[i]), i)); + qu.pop(); + } + order.resize(k); + for(int i = qu.size() - 1; i >= 0; i--){ + order[i] = qu.top().id; + qu.pop(); + } +} + +static std::vector<double> v; + +static bool sf(const int& a, const int& b){ + if(dist[a] != dist[b]) + return (dist[a] < dist[b]); + return (a < b); +} + +static void show(std::vector<double> const& ask, std::vector<int> kd, std::vector<int> me, int k){ + if(N <= 30 && D <= 3){ + printf("\nData:\n"); + for(int i = 0; i < N; i++){ + printf(" %2d) <", i); + for(int j = 0; j < D; j++){ + printf("%.7f", data[i][j]); + if(j < D - 1) printf(", "); + else printf(">"); + } + printf("\n"); + } + printf("Ask <"); + for(int j = 0; j < D; j++){ + printf("%.7f", ask[j]); + if(j < D - 1) printf(", "); + else printf(">"); + } + printf("\n"); + printf("MyAnswer: "); + for(int i = 0; i < k; i++) printf("%d ", me[i]); + printf("\n"); + printf("KdAnswer: "); + for(int i = 0; i < k; i++) printf("%d ", kd[i]); + printf("\n"); + order.resize(N); + dist .resize(N); + for(int i = 0; i < N; i++){ + dist [i] = dist2(ask, data[i]); + order[i] = i; + } + std::sort(order.begin(), order.end(), sf); + printf("Sorted:\n"); + for(int i = 0; i < N; i++){ + printf(" %2d) <", order[i]); + for(int j = 0; j < D; j++){ + printf("%.7f", data[order[i]][j]); + if(j < D - 1) printf(", "); + else printf(">"); + } + printf(" ((%.7f))", dist[order[i]]); + printf("\n"); + } + } +} + +struct Node{ + std::vector<double> v; + int id; + double& operator[](size_t d) { return v[d]; } + double operator[](size_t d) const { return v[d]; } + bool operator<(Node const& n) const{ return (id < n.id); } +}; + +TEST(KD_Tree){ + + int t0, t1, t2; + + meow::KD_Tree<Node, double> tree(D); + + meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D); + data.resize(N); + for(int i = 0; i < N; i++){ + data[i].resize(D); + Node nd; + nd.v.resize(D); + nd.id = i; + for(int j = 0; j < D; j++){ + data[i][j] = 12345.0 * (1.0 * rand() / RAND_MAX - 0.3); + nd[j] = data[i][j]; + } + tree.insert(nd); + } + meow::messagePrintf(-1, "ok"); + meow::messagePrintf(1, "build"); + t0 = clock(); + tree.build(); + meow::messagePrintf(-1, "ok, %.3f seconds", (clock() - t0) * 1.0 / CLOCKS_PER_SEC); + + meow::messagePrintf(1, "query..."); + v.resize(D); + meow::KD_Tree<Node, double>::Vectors ret; + int id; + for(int k = 1; k <= std::min(100, N); k++){ + meow::messagePrintf(1, "range k = %d", k); + t1 = t2 = 0; + for(int i = 0; i < 10; i++){ + Node nd; + nd.v.resize(D); + for(int d = 0; d < D; d++){ + v[d] = 12345.0 * (1.0 * rand() / RAND_MAX - 0.3); + nd[d] = v[d]; + } + t0 = clock(); + tree.build(); + ret = tree.query(nd, k, true); + t1 += clock() - t0; + + t0 = clock(); + find(v, k); + t2 += clock() - t0; + if(ret.size() != std::min(k, N)){ + meow::messagePrintf(-1, "(%d)query fail, size error", i); + meow::messagePrintf(-1, "fail"); + return false; + } + for(int kk = 1; kk <= k; kk++){ + if(order[kk - 1] != ret[kk - 1].id){ + //show(v, ret, order, k); + meow::messagePrintf(-1, "(%d)query fail", i); + meow::messagePrintf(-1, "fail"); + return false; + } + } + } + meow::messagePrintf(-1, "ok %.3f/%.3f", + t1 * 1.0 / CLOCKS_PER_SEC, + t2 * 1.0 / CLOCKS_PER_SEC + ); + } + meow::messagePrintf(-1, "ok"); + + + return true; +}; |