#include "meowpp/dsa/KD_Tree.h"
#include "meowpp/utility.h"

#include "meowpp.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <queue>
#include <vector>

// Brute-force cross-check of meow::KD_Tree nearest-neighbour queries.
//
// N random D-dimensional points are inserted into a KD tree. For every
// k in [1, min(100, N)], ten random query points are answered both by the
// tree and by an exhaustive scan (find); the two id lists must match.

static int N = 10000;  // number of data points
static int D = 5;      // dimension of every point

// Squared Euclidean distance over the first D coordinates of v1 and v2.
static double dist2(std::vector<double> const& v1, std::vector<double> const& v2){
  double ret = 0;
  for(int i = 0; i < D; i++){
    ret += meow::squ(v1[i] - v2[i]);
  }
  return ret;
}

static std::vector< std::vector<double> > data;  // reference copy of the points, indexed by id
static std::vector< double > dist;               // scratch distances used by sf()/show()
static std::vector< int > order;                 // output of find(): ids of the k nearest points

// Brute-force candidate ordered by (distance, id). The id tie-break is
// presumably meant to match how KD_Tree orders equidistant points
// (Node::operator< compares ids) -- TODO confirm against KD_Tree.h.
struct Answer{
  double dist;
  int id;
  Answer(double _dist, int _id): dist(_dist), id(_id){ }
  bool operator<(Answer const& b) const{
    if(dist != b.dist) return (dist < b.dist);
    return (id < b.id);
  }
};

// Exhaustive k-nearest-neighbour search over `data`.
// Writes the ids of the k closest points to the global `order`, nearest
// first. k is clamped to [0, N] so data[] is never indexed out of range.
static void find(std::vector<double> const& v, int k){
  k = std::max(0, std::min(k, N));
  std::priority_queue<Answer> qu;  // max-heap: top() is the worst kept candidate
  for(int i = 0; i < k; i++){
    qu.push(Answer(dist2(v, data[i]), i));
  }
  for(int i = k; i < N; i++){
    qu.push(Answer(dist2(v, data[i]), i));
    qu.pop();  // evict the farthest so exactly k candidates remain
  }
  order.resize(k);
  // The heap pops farthest-first, so fill `order` from the back.
  for(int i = k - 1; i >= 0; i--){
    order[i] = qu.top().id;
    qu.pop();
  }
}

static std::vector<double> v;  // current query point

// Sort predicate for point ids: ascending dist[], ties broken by id.
static bool sf(const int& a, const int& b){
  if(dist[a] != dist[b]) return (dist[a] < dist[b]);
  return (a < b);
}

// Prints the first D coordinates of p as "x0, x1, ...>" (the caller
// prints the opening '<').
static void printCoords(std::vector<double> const& p){
  for(int j = 0; j < D; j++){
    printf("%.7f", p[j]);
    fputs(j < D - 1 ? ", " : ">", stdout);
  }
}

// Debug dump, active only for tiny inputs (N <= 30 and D <= 3).
// Prints the data set, the query point, both answer id lists (at most k
// entries each) and every point ranked by distance to the query.
// Overwrites the global `dist` scratch buffer; `order` is left untouched
// so that `me` (normally the global `order`) stays valid.
static void show(std::vector<double> const& ask,
                 std::vector<int> const& kd,
                 std::vector<int> const& me, int k){
  if(N > 30 || D > 3) return;
  printf("\nData:\n");
  for(int i = 0; i < N; i++){
    printf(" %2d) <", i);
    printCoords(data[i]);
    printf("\n");
  }
  printf("Ask <");
  printCoords(ask);
  printf("\n");
  printf("MyAnswer: ");
  for(int i = 0; i < k && i < static_cast<int>(me.size()); i++) printf("%d ", me[i]);
  printf("\n");
  printf("KdAnswer: ");
  for(int i = 0; i < k && i < static_cast<int>(kd.size()); i++) printf("%d ", kd[i]);
  printf("\n");
  std::vector<int> sorted(N);
  dist.resize(N);
  for(int i = 0; i < N; i++){
    dist[i] = dist2(ask, data[i]);
    sorted[i] = i;
  }
  std::sort(sorted.begin(), sorted.end(), sf);
  printf("Sorted:\n");
  for(int i = 0; i < N; i++){
    int const id = sorted[i];
    printf(" %2d) <", id);
    printCoords(data[id]);
    printf(" ((%.7f))", dist[id]);
    printf("\n");
  }
}

// Point type stored in the KD tree: coordinates plus the id used to match
// tree results against the brute-force answer.
struct Node{
  std::vector<double> v;
  int id;
  double& operator[](size_t d) { return v[d]; }
  double operator[](size_t d) const { return v[d]; }
  bool operator<(Node const& n) const{ return (id < n.id); }
};

// Uniform-ish random coordinate in [-0.3 * 12345, 0.7 * 12345].
static double randomCoord(){
  return 12345.0 * (1.0 * rand() / RAND_MAX - 0.3);
}

TEST(KD_Tree){
  meow::KD_Tree<Node, double> tree(D);

  meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D);
  data.resize(N);
  for(int i = 0; i < N; i++){
    data[i].resize(D);
    Node nd;
    nd.v.resize(D);
    nd.id = i;
    for(int j = 0; j < D; j++){
      data[i][j] = randomCoord();
      nd[j] = data[i][j];
    }
    tree.insert(nd);
  }
  meow::messagePrintf(-1, "ok");

  meow::messagePrintf(1, "build");
  std::clock_t t0 = std::clock();
  tree.build();
  meow::messagePrintf(-1, "ok, %.3f seconds", (std::clock() - t0) * 1.0 / CLOCKS_PER_SEC);

  meow::messagePrintf(1, "query...");
  v.resize(D);
  meow::KD_Tree<Node, double>::Vectors ret;
  int const maxK = std::min(100, N);
  for(int k = 1; k <= maxK; k++){
    meow::messagePrintf(1, "range k = %d", k);
    std::clock_t t1 = 0;  // accumulated KD-tree query time
    std::clock_t t2 = 0;  // accumulated brute-force time
    for(int i = 0; i < 10; i++){
      Node nd;
      nd.v.resize(D);
      for(int d = 0; d < D; d++){
        v[d] = randomCoord();
        nd[d] = v[d];
      }
      // The tree was built once above and is not modified afterwards, so
      // only the query itself is timed.
      t0 = std::clock();
      ret = tree.query(nd, k, true);
      t1 += std::clock() - t0;
      t0 = std::clock();
      find(v, k);
      t2 += std::clock() - t0;
      if(ret.size() != static_cast<size_t>(std::min(k, N))){
        meow::messagePrintf(-1, "(%d)query fail, size error", i);
        meow::messagePrintf(-1, "fail");
        return false;
      }
      for(int kk = 0; kk < k; kk++){
        if(order[kk] != ret[kk].id){
          // For small N/D, dump the mismatch by collecting ret[*].id into a
          // std::vector<int> kdIds and calling: show(v, kdIds, order, k);
          meow::messagePrintf(-1, "(%d)query fail", i);
          meow::messagePrintf(-1, "fail");
          return false;
        }
      }
    }
    meow::messagePrintf(-1, "ok %.3f/%.3f",
                        t1 * 1.0 / CLOCKS_PER_SEC,
                        t2 * 1.0 / CLOCKS_PER_SEC);
  }
  meow::messagePrintf(-1, "ok");
  return true;
};