aboutsummaryrefslogblamecommitdiffstats
path: root/meowpp.test/src/KD_Tree.cpp
blob: 8d4232eeb7f09fd1941a631a7c91a7dfe36c88cb (plain) (tree)





























































































































































































                                                                                                  
#include "meowpp/dsa/KD_Tree.h"
#include "meowpp/utility.h"

#include "dsa.h"

#include <vector>

#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <ctime>
#include <queue>

// Test configuration: number of random points (N) and their dimension (D).
// Both are only read, never written, so they are declared const.
static const int N = 10000;
static const int D = 5;

// Squared Euclidean distance between v1 and v2 over the first D coordinates.
// No square root is taken; since sqrt is monotonic, ranking points by this
// value gives the same order as ranking by the true distance.
static double dist2(std::vector<double> const& v1, std::vector<double> const& v2){
  double sum = 0;
  int dim = 0;
  while(dim < D){
    double const delta = v1[dim] - v2[dim];
    sum += meow::squ(delta);
    ++dim;
  }
  return sum;
}

// File-level state shared by the helpers and the test:
//   data  - every generated point, kept for the brute-force reference search;
//   dist  - per-point distance scratch, referenced only by the disabled
//           debugging helpers in this file;
//   order - point ids written by find(), nearest first.
static std::vector<std::vector<double> > data;
static std::vector<double>               dist;
static std::vector<int>                  order;


// Candidate of the brute-force nearest-neighbour search: a distance value
// paired with the index of the point it was computed for.
struct Answer{
  double dist;
  int    id;
  Answer(double d, int index): dist(d), id(index){ }
  // Orders by distance first; equal distances fall back to the smaller id,
  // so two candidates with different ids are never equivalent.
  bool operator<(Answer const& other) const{
    if(dist == other.dist) return (id < other.id);
    return (dist < other.dist);
  }
};


// Brute-force reference k-nearest-neighbour search used to validate KD_Tree.
//
// Scans the N points in `data`, keeps the k closest to `v` (by dist2, ties
// broken by the smaller id through Answer::operator<), and writes their ids
// into the global `order`, nearest first.
//
// @param v  query point (needs at least D coordinates)
// @param k  number of neighbours wanted; clamped to [0, N] so the scan never
//           indexes outside data[0..N-1]. `order` ends up with that many ids.
static void find(std::vector<double> const& v, int k){
  // Without the clamp, k > N makes the seeding loop read data[N..k-1] and
  // k < 0 makes the second loop start at a negative index.
  if(k < 0) k = 0;
  if(k > N) k = N;

  // std::priority_queue is a max-heap: top() is the worst current candidate,
  // so push-then-pop keeps exactly the best k seen so far.
  std::priority_queue<Answer> qu;
  for(int i = 0; i < k; i++){
    qu.push(Answer(dist2(v, data[i]), i));
  }
  for(int i = k; i < N; i++){
    qu.push(Answer(dist2(v, data[i]), i));
    qu.pop();
  }

  // Drain worst-first, filling `order` from the back so it ends nearest-first.
  order.resize(k);
  for(int i = static_cast<int>(qu.size()) - 1; i >= 0; i--){
    order[i] = qu.top().id;
    qu.pop();
  }
}

// Coordinates of the current query point; filled by the KD_Tree test.
static std::vector<double> v;

/*
static bool sf(const int& a, const int& b){
  if(dist[a] != dist[b])
    return (dist[a] < dist[b]);
  return (a < b);
}

static void show(std::vector<double> const& ask, std::vector<int> kd, std::vector<int> me, int k){
  if(N <= 30 && D <= 3){
    printf("\nData:\n");
    for(int i = 0; i < N; i++){
      printf("  %2d) <", i);
      for(int j = 0; j < D; j++){
        printf("%.7f", data[i][j]);
        if(j < D - 1) printf(", ");
        else          printf(">");
      }
      printf("\n");
    }
    printf("Ask  <");
    for(int j = 0; j < D; j++){
      printf("%.7f", ask[j]);
      if(j < D - 1) printf(", ");
      else          printf(">");
    }
    printf("\n");
    printf("MyAnswer: ");
    for(int i = 0; i < k; i++) printf("%d ", me[i]);
    printf("\n");
    printf("KdAnswer: ");
    for(int i = 0; i < k; i++) printf("%d ", kd[i]);
    printf("\n");
    order.resize(N);
    dist .resize(N);
    for(int i = 0; i < N; i++){
      dist [i] = dist2(ask, data[i]);
      order[i] = i;
    }
    std::sort(order.begin(), order.end(), sf);
    printf("Sorted:\n");
    for(int i = 0; i < N; i++){
      printf("  %2d) <", order[i]);
      for(int j = 0; j < D; j++){
        printf("%.7f", data[order[i]][j]);
        if(j < D - 1) printf(", ");
        else          printf(">");
      }
      printf(" ((%.7f))", dist[order[i]]);
      printf("\n");
    }
  }
}
// */

// Element type stored in meow::KD_Tree: a coordinate vector plus the index of
// the corresponding point in `data`.
struct Node{
  std::vector<double> v;  // coordinates, one entry per dimension
  int id;                 // index of this point in `data`

  // Unchecked coordinate access (mutable and read-only overloads).
  double& operator[](size_t dim){
    return this->v[dim];
  }
  double operator[](size_t dim) const{
    return this->v[dim];
  }

  // Nodes are ordered by id alone; coordinates do not take part.
  bool operator<(Node const& other) const{
    return other.id > id;
  }
};

// Draws one random coordinate: 12345 * (u - 0.3) with u = rand() / RAND_MAX,
// i.e. a value in [-3703.5, 8641.5]. Used for both data points and queries
// so they share the same distribution.
static double randomCoordinate(){
  return 12345.0 * (1.0 * std::rand() / RAND_MAX - 0.3);
}

// KD_Tree correctness and timing test.
//
// The test runs in these steps:
// - Insert N random D-dimensional points into a KD_Tree, mirroring them in
//   the global `data`, then build the tree.
// - For every k in [1, min(100, N)], run 10 random queries.
// - For each query, compare KD_Tree::query against the brute-force find().
//   The result size must equal min(k, N), and the ids must match
//   position-by-position with `order` (nearest first).
//
// Returns false on the first mismatch. Otherwise it reports the accumulated
// CPU time per k as "kd/brute".
TEST(KD_Tree, "It is very slow"){
  meow::KD_Tree<Node, double> tree(D);

  meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D);
  data.resize(N);
  for(int i = 0; i < N; i++){
    data[i].resize(D);
    Node nd;
    nd.v.resize(D);
    nd.id = i;
    for(int j = 0; j < D; j++){
      data[i][j] = randomCoordinate();
      nd[j] = data[i][j];
    }
    tree.insert(nd);
  }
  meow::messagePrintf(-1, "ok");

  meow::messagePrintf(1, "build");
  // Timings are kept as std::clock_t (the type clock() returns) rather than
  // int, so large clock readings are not truncated.
  std::clock_t const buildStart = std::clock();
  tree.build();
  meow::messagePrintf(-1, "ok, %.3f seconds",
                      (std::clock() - buildStart) * 1.0 / CLOCKS_PER_SEC);

  meow::messagePrintf(1, "query...");
  v.resize(D);
  meow::KD_Tree<Node, double>::Vectors ret;
  int const maxK = std::min(100, N);
  for(int k = 1; k <= maxK; k++){
    meow::messagePrintf(1, "range k = %d", k);
    std::clock_t kdTime    = 0;  // KD_Tree build + query time for this k
    std::clock_t bruteTime = 0;  // find() time for this k
    for(int i = 0; i < 10; i++){
      Node nd;
      nd.v.resize(D);
      for(int d = 0; d < D; d++){
        v[d] = randomCoordinate();
        nd[d] = v[d];
      }
      std::clock_t t0 = std::clock();
      // NOTE(review): build() runs again before every query and is counted
      // in kdTime; confirm whether re-building here is intended.
      tree.build();
      ret = tree.query(nd, k, true);
      kdTime += std::clock() - t0;

      t0 = std::clock();
      find(v, k);
      bruteTime += std::clock() - t0;

      if(static_cast<int>(ret.size()) != std::min(k, N)){
        meow::messagePrintf(-1, "(%d)query fail, size error", i);
        meow::messagePrintf(-1, "fail");
        return false;
      }
      for(int j = 0; j < k; j++){
        if(order[j] != ret[j].id){
          //show(v, ret, order, k);
          meow::messagePrintf(-1, "(%d)query fail", i);
          meow::messagePrintf(-1, "fail");
          return false;
        }
      }
    }
    meow::messagePrintf(-1, "ok %.3f/%.3f",
                        kdTime    * 1.0 / CLOCKS_PER_SEC,
                        bruteTime * 1.0 / CLOCKS_PER_SEC
                        );
  }
  meow::messagePrintf(-1, "ok");

  return true;
}