aboutsummaryrefslogblamecommitdiffstats
path: root/_test/meowpp_VP_Tree.cpp
blob: 34c979ce943a6f19184940dc00cc0142a910a889 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12











                      

                      





































































































































































                                                                                               
#include "meowpp.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <queue>
#include <vector>

// Test parameters: number of points, dimension of each point, and the
// exclusive upper bound for random coordinates (rand() % MAX).
static int N = 100000;
static int D = 32;
static int MAX = 1000;

typedef long long lnt;

// D-dimensional integer point used as the element type of the tree.
//   v : the D coordinates
//   w : an id; operator< looks only at w, and operator== requires both the
//       coordinates and w to match.
// Copy/assignment are the implicit memberwise ones.
struct MyVector{
  std::vector<lnt> v;
  int              w;
  //
  // Random point: every coordinate is rand() % MAX. w is set to 0 so that
  // copying a freshly constructed MyVector never reads an indeterminate value.
  MyVector(     ):v(D), w(0){ for(int i = 0; i < D; i++){ v[i] = (lnt)rand() % MAX; } }
  // Point with every coordinate equal to k; w is set to 0.
  MyVector(lnt k):v(D, k), w(0){ }
  //
  lnt      & operator[](size_t n)      { return v[n]; }
  lnt const& operator[](size_t n) const{ return v[n]; }
  // Orders points by id only.
  bool operator<(MyVector const& v2) const{ return (w < v2.w); }
  // Equal when all D coordinates and the id are equal.
  bool operator==(MyVector const& v2) const{
    for(int i = 0; i < D; i++) if(v[i] != v2[i]) return false;
    return (w == v2.w);
  }
};


// Squared Euclidean distance between two D-dimensional points. The square
// root is never taken, so the result stays an exact integer.
static lnt dist2(MyVector const& v1, MyVector const& v2){
  lnt sum = 0;
  int axis = 0;
  while(axis < D){
    lnt const diff = v1[axis] - v2[axis];
    sum += diff * diff;
    ++axis;
  }
  return sum;
}

// Point set shared by show(), VP::Answer, find() and TEST(VP_Tree).
// The test resizes it to N points and sets each point's w to its index.
static std::vector<MyVector> data;

void show(MyVector const& v, std::vector<MyVector> const& r1, std::vector<MyVector> const& r2){
  if(N <= 20 && r1.size() <= 7){
    printf("\n");
    for(int i = 0; i < N; i++){
      printf("%3d) ", data[i].w);
      for(int j = 0; j < D; j++)
        printf("%8lld ", data[i][j]);
      printf(" ===> %lld\n", dist2(data[i], v));
    }
    printf("\n");
    printf("ask) ");
    for(int j = 0; j < D; j++)
      printf("%8lld ", v[j]);
    printf("\n");
    printf("---------\n");
    for(int i = 0; i < r1.size(); i++){
      printf("%3d) ", r1[i].w);
      for(int j = 0; j < D; j++)
        printf("%8lld ", r1[i][j]);
      printf(" ===> %lld\n", dist2(r1[i], v));
    }
    printf("---------\n");
    for(int i = 0; i < r2.size(); i++){
      printf("%3d) ", r2[i].w);
      for(int j = 0; j < D; j++)
        printf("%8lld ", r2[i][j]);
      printf(" ===> %lld\n", dist2(r2[i], v));
    }
  }
}

namespace VP{
  // One candidate of the brute-force k-nearest search in find(): the index
  // of a point in `data` and its squared distance to the query.
  struct Answer{
    int i;  // index into data
    lnt d;  // squared distance to the query
    //
    Answer(int _i, lnt _d): i(_i), d(_d){ }
    Answer(Answer const& other): i(other.i), d(other.d){ }
    //
    // Smaller distance first. Equal distances fall back to
    // MyVector::operator< on the referenced data points.
    bool operator<(Answer const& b) const{
      return (d == b.d) ? (data[i] < data[b.i]) : (d < b.d);
    }
  };
}

// Brute-force reference for the tree query. Returns the min(k, N) points of
// `data` closest to `v`, ordered nearest first. Points are compared by
// VP::Answer::operator<: dist2 first, then MyVector::operator< on ties.
// Returns an empty vector when k <= 0.
static std::vector<MyVector> find(MyVector const& v, int k){
  std::vector<MyVector> ret;
  // Guard: a negative k would otherwise start the scan below at data[k < 0],
  // and k == 0 would call top() on an empty heap.
  if(k <= 0) return ret;
  int const keep = std::min(k, N);
  // Max-heap holding the best `keep` candidates seen so far; top() is the worst.
  std::priority_queue<VP::Answer> qu;
  for(int i = 0; i < keep; i++){
    qu.push(VP::Answer(i, dist2(v, data[i])));
  }
  for(int i = keep; i < N; i++){
    VP::Answer cand(i, dist2(v, data[i]));
    // Same result as push-then-pop, but the heap is only touched when the
    // candidate actually beats the current worst kept answer.
    if(cand < qu.top()){
      qu.pop();
      qu.push(cand);
    }
  }
  // Build the result without default-constructing placeholders: a default
  // MyVector calls rand() D times, which would skew timing and the RNG stream.
  ret.reserve(qu.size());
  while(!qu.empty()){
    ret.push_back(data[qu.top().i]);
    qu.pop();
  }
  std::reverse(ret.begin(), ret.end());  // heap pops farthest first
  return ret;
}

TEST(VP_Tree){
  // clock() returns clock_t; keep that type so differences and sums are not
  // narrowed into int.
  std::clock_t t0, t1, t2;
  
  meow::VP_Tree<MyVector, lnt> tree(D);
  
  meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D);
  data.resize(N);
  for(int i = 0; i < N; i++){
    // Points 0..N/10 have all coordinates equal to i; the rest get random
    // coordinates in [0, MAX).
    if(i <= N / 10)
      data[i] = MyVector((lnt)i);
    else{
      for(int j = 0; j < D; j++){
        data[i][j] = rand() % MAX;
      }
    }
  }
  for(int i = 0; i < N; i++){
    data[i].w = i;  // the id doubles as the tie-breaker in MyVector::operator<
  }
  for(int i = 0; i < N; i++){
    tree.insert(data[i]);
  }
  meow::messagePrintf(-1, "ok");
  meow::messagePrintf(1, "build");
  t0 = clock();
  tree.build();
  //tree.print();
  meow::messagePrintf(-1, "ok, %.3f seconds", (clock() - t0) * 1.0 / CLOCKS_PER_SEC);
  
  meow::messagePrintf(1, "query...");
  // Both results use the type find() returns.
  std::vector<MyVector> ret1, ret2;
  for(int k = 1; k <= std::min(100, N); k++){
    meow::messagePrintf(1, "range k = %d", k);
    t1 = t2 = 0;
    for(int i = 0; i < 10; i++){
      MyVector ask;  // random query point
      
      t0 = clock();
      tree.build();
      ret1 = tree.query(ask, k, true);
      t1 += clock() - t0;
      
      t0 = clock();
      ret2 = find(ask, k);
      t2 += clock() - t0;
      
      // NOTE(review): a size mismatch between ret1 and ret2 is not treated as
      // a failure; the original size check was disabled with `&& false`.
      // Only the common prefix is compared, so a longer ret1 can never index
      // past the end of ret2.
      size_t const common = std::min(ret1.size(), ret2.size());
      for(size_t kk = 0; kk < common; kk++){
        if(ret1[kk] == ret2[kk]){
          continue;
        }
        show(ask, ret1, ret2);
        meow::messagePrintf(-1, "(%d)query fail", i);
        meow::messagePrintf(-1, "fail");
        return false;
      }
    }
    meow::messagePrintf(-1, "ok %.3f/%.3f",
                        t1 * 1.0 / CLOCKS_PER_SEC,
                        t2 * 1.0 / CLOCKS_PER_SEC
                        );
  }
  meow::messagePrintf(-1, "ok");
  
  
  return true;
};
;