aboutsummaryrefslogtreecommitdiffstats
path: root/_test/meowpp_KD_Tree.cpp
diff options
context:
space:
mode:
Diffstat (limited to '_test/meowpp_KD_Tree.cpp')
-rw-r--r--_test/meowpp_KD_Tree.cpp186
1 files changed, 186 insertions, 0 deletions
diff --git a/_test/meowpp_KD_Tree.cpp b/_test/meowpp_KD_Tree.cpp
new file mode 100644
index 0000000..dcbda5f
--- /dev/null
+++ b/_test/meowpp_KD_Tree.cpp
@@ -0,0 +1,186 @@
+#include "meowpp.h"
+
+#include <vector>
+
+#include <cmath>
+#include <cstdlib>
+#include <algorithm>
+#include <ctime>
+#include <queue>
+
+static int N = 10000;
+static int D = 5;
+
+static double dist2(std::vector<double> const& v1, std::vector<double> const& v2){
+ double ret = 0;
+ for(int i = 0; i < D; i++){
+ ret += meow::squ(v1[i] - v2[i]);
+ }
+ return ret;
+}
+
+static std::vector< std::vector<double> > data;
+static std::vector< double > dist;
+static std::vector< int > order;
+
+
+struct Answer{
+ double dist;
+ int id;
+ Answer(double _dist, int _id): dist(_dist), id(_id){ }
+ bool operator<(Answer const& b) const{
+ if(dist != b.dist) return (dist < b.dist);
+ return (id < b.id);
+ }
+};
+
+
+static void find(std::vector<double> const& v, int k){
+ std::priority_queue<Answer> qu;
+ for(int i = 0; i < k; i++){
+ qu.push(Answer(dist2(v, data[i]), i));
+ }
+ for(int i = k; i < N; i++){
+ qu.push(Answer(dist2(v, data[i]), i));
+ qu.pop();
+ }
+ order.resize(k);
+ for(int i = qu.size() - 1; i >= 0; i--){
+ order[i] = qu.top().id;
+ qu.pop();
+ }
+}
+
+static std::vector<double> v;
+
+static bool sf(const int& a, const int& b){
+ if(dist[a] != dist[b])
+ return (dist[a] < dist[b]);
+ return (a < b);
+}
+
+static void show(std::vector<double> const& ask, std::vector<int> kd, std::vector<int> me, int k){
+ if(N <= 30 && D <= 3){
+ printf("\nData:\n");
+ for(int i = 0; i < N; i++){
+ printf(" %2d) <", i);
+ for(int j = 0; j < D; j++){
+ printf("%.7f", data[i][j]);
+ if(j < D - 1) printf(", ");
+ else printf(">");
+ }
+ printf("\n");
+ }
+ printf("Ask <");
+ for(int j = 0; j < D; j++){
+ printf("%.7f", ask[j]);
+ if(j < D - 1) printf(", ");
+ else printf(">");
+ }
+ printf("\n");
+ printf("MyAnswer: ");
+ for(int i = 0; i < k; i++) printf("%d ", me[i]);
+ printf("\n");
+ printf("KdAnswer: ");
+ for(int i = 0; i < k; i++) printf("%d ", kd[i]);
+ printf("\n");
+ order.resize(N);
+ dist .resize(N);
+ for(int i = 0; i < N; i++){
+ dist [i] = dist2(ask, data[i]);
+ order[i] = i;
+ }
+ std::sort(order.begin(), order.end(), sf);
+ printf("Sorted:\n");
+ for(int i = 0; i < N; i++){
+ printf(" %2d) <", order[i]);
+ for(int j = 0; j < D; j++){
+ printf("%.7f", data[order[i]][j]);
+ if(j < D - 1) printf(", ");
+ else printf(">");
+ }
+ printf(" ((%.7f))", dist[order[i]]);
+ printf("\n");
+ }
+ }
+}
+
+struct Node{
+ std::vector<double> v;
+ int id;
+ double& operator[](size_t d) { return v[d]; }
+ double operator[](size_t d) const { return v[d]; }
+ bool operator<(Node const& n) const{ return (id < n.id); }
+};
+
+TEST(KD_Tree){
+
+ int t0, t1, t2;
+
+ meow::KD_Tree<Node, double> tree(D);
+
+ meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D);
+ data.resize(N);
+ for(int i = 0; i < N; i++){
+ data[i].resize(D);
+ Node nd;
+ nd.v.resize(D);
+ nd.id = i;
+ for(int j = 0; j < D; j++){
+ data[i][j] = 12345.0 * (1.0 * rand() / RAND_MAX - 0.3);
+ nd[j] = data[i][j];
+ }
+ tree.insert(nd);
+ }
+ meow::messagePrintf(-1, "ok");
+ meow::messagePrintf(1, "build");
+ t0 = clock();
+ tree.build();
+ meow::messagePrintf(-1, "ok, %.3f seconds", (clock() - t0) * 1.0 / CLOCKS_PER_SEC);
+
+ meow::messagePrintf(1, "query...");
+ v.resize(D);
+ meow::KD_Tree<Node, double>::Vectors ret;
+ int id;
+ for(int k = 1; k <= std::min(100, N); k++){
+ meow::messagePrintf(1, "range k = %d", k);
+ t1 = t2 = 0;
+ for(int i = 0; i < 10; i++){
+ Node nd;
+ nd.v.resize(D);
+ for(int d = 0; d < D; d++){
+ v[d] = 12345.0 * (1.0 * rand() / RAND_MAX - 0.3);
+ nd[d] = v[d];
+ }
+ t0 = clock();
+ tree.build();
+ ret = tree.query(nd, k, true);
+ t1 += clock() - t0;
+
+ t0 = clock();
+ find(v, k);
+ t2 += clock() - t0;
+ if(ret.size() != std::min(k, N)){
+ meow::messagePrintf(-1, "(%d)query fail, size error", i);
+ meow::messagePrintf(-1, "fail");
+ return false;
+ }
+ for(int kk = 1; kk <= k; kk++){
+ if(order[kk - 1] != ret[kk - 1].id){
+ //show(v, ret, order, k);
+ meow::messagePrintf(-1, "(%d)query fail", i);
+ meow::messagePrintf(-1, "fail");
+ return false;
+ }
+ }
+ }
+ meow::messagePrintf(-1, "ok %.3f/%.3f",
+ t1 * 1.0 / CLOCKS_PER_SEC,
+ t2 * 1.0 / CLOCKS_PER_SEC
+ );
+ }
+ meow::messagePrintf(-1, "ok");
+
+
+ return true;
+};