#include "meowpp/dsa/KD_Tree.h"
#include "meowpp/utility.h"
#include "dsa.h"
#include <vector>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <ctime>
#include <queue>
// Size of the random test set: N points, each with D coordinates.
static int N(10000);
static int D(5);
// Squared Euclidean distance between v1 and v2 over the first D coordinates
// (D is the file-level dimensionality). Coordinates are accumulated in index
// order 0..D-1.
static double dist2(std::vector<double> const& v1, std::vector<double> const& v2){
    double acc = 0.0;
    int dim = 0;
    while(dim < D){
        acc += meow::squ(v1[dim] - v2[dim]);
        ++dim;
    }
    return acc;
}
// Reference copy of every generated point, indexed by the point's id.
static std::vector< std::vector<double> > data;
// Output of the brute-force find(): ids of the k nearest points, nearest first.
static std::vector< int > order;
// Per-point squared distances; in this file only the commented-out debug
// helpers read or write it.
static std::vector< double > dist;
// A brute-force search candidate: squared distance to the query plus point id.
// Ordered by distance first, then by id, so that among equally distant points
// the one with the larger id compares greater.
struct Answer{
    double dist;
    int id;
    Answer(double d, int ident): dist(d), id(ident){ }
    bool operator<(Answer const& other) const{
        return (dist == other.dist) ? (id < other.id) : (dist < other.dist);
    }
};
// Brute-force reference k-nearest-neighbour search.
//
// Computes the ids of the k points in the global `data` (first N entries)
// closest to `v` under dist2(), breaking distance ties in favour of the
// smaller id, and stores them in the global `order`, nearest first.
//
// k is clamped into [0, N]: asking for more neighbours than there are points
// (or a negative count) would otherwise index outside `data` and pass a
// negative size to order.resize().
static void find(std::vector<double> const& v, int k){
    if(k < 0) k = 0;
    if(k > N) k = N;
    // Max-heap of the best k candidates seen so far; top() is the worst one.
    std::priority_queue<Answer> qu;
    for(int i = 0; i < k; i++){
        qu.push(Answer(dist2(v, data[i]), i));
    }
    for(int i = k; i < N; i++){
        // Push then pop evicts whichever of the k+1 candidates is farthest.
        qu.push(Answer(dist2(v, data[i]), i));
        qu.pop();
    }
    // Drain worst-first from the back so that order[0] ends up the nearest.
    order.resize(k);
    for(int i = static_cast<int>(qu.size()) - 1; i >= 0; i--){
        order[i] = qu.top().id;
        qu.pop();
    }
}
// Scratch buffer for the current query point; resized and filled in the TEST body.
static std::vector< double > v;
/*
static bool sf(const int& a, const int& b){
if(dist[a] != dist[b])
return (dist[a] < dist[b]);
return (a < b);
}
static void show(std::vector<double> const& ask, std::vector<int> kd, std::vector<int> me, int k){
if(N <= 30 && D <= 3){
printf("\nData:\n");
for(int i = 0; i < N; i++){
printf(" %2d) <", i);
for(int j = 0; j < D; j++){
printf("%.7f", data[i][j]);
if(j < D - 1) printf(", ");
else printf(">");
}
printf("\n");
}
printf("Ask <");
for(int j = 0; j < D; j++){
printf("%.7f", ask[j]);
if(j < D - 1) printf(", ");
else printf(">");
}
printf("\n");
printf("MyAnswer: ");
for(int i = 0; i < k; i++) printf("%d ", me[i]);
printf("\n");
printf("KdAnswer: ");
for(int i = 0; i < k; i++) printf("%d ", kd[i]);
printf("\n");
order.resize(N);
dist .resize(N);
for(int i = 0; i < N; i++){
dist [i] = dist2(ask, data[i]);
order[i] = i;
}
std::sort(order.begin(), order.end(), sf);
printf("Sorted:\n");
for(int i = 0; i < N; i++){
printf(" %2d) <", order[i]);
for(int j = 0; j < D; j++){
printf("%.7f", data[order[i]][j]);
if(j < D - 1) printf(", ");
else printf(">");
}
printf(" ((%.7f))", dist[order[i]]);
printf("\n");
}
}
}
// */
// Element type stored in the KD-tree: a coordinate vector plus an id.
// operator[] gives (mutable or read-only) access to one coordinate;
// operator< orders nodes by id only.
struct Node{
    std::vector<double> v;  // coordinates, one per dimension
    int id;                 // set to the point's index in `data` on insertion
    double& operator[](size_t dim){
        return this->v[dim];
    }
    double operator[](size_t dim) const{
        return this->v[dim];
    }
    bool operator<(Node const& other) const{
        return this->id < other.id;
    }
};
// Randomised correctness and timing test for meow::KD_Tree.
//
// Inserts N random D-dimensional points (mirrored in `data`), builds the tree,
// then for each k in [1, min(100, N)] runs 10 random queries. For each one it
// checks that tree.query(nd, k, true) returns exactly the same ids, in the
// same order, as the brute-force find(). It reports the accumulated
// KD-tree / brute-force CPU times per k and returns false on the first mismatch.
TEST(KD_Tree, "It is very slow"){
    // clock() yields clock_t; keeping readings and accumulated deltas in int
    // can overflow on long runs, so store them in clock_t.
    std::clock_t t0, t1, t2;
    meow::KD_Tree<Node, double> tree(D);
    meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D);
    data.resize(N);
    for(int i = 0; i < N; i++){
        data[i].resize(D);
        Node nd;
        nd.v.resize(D);
        nd.id = i;
        for(int j = 0; j < D; j++){
            data[i][j] = 12345.0 * (1.0 * rand() / RAND_MAX - 0.3);
            nd[j] = data[i][j];
        }
        tree.insert(nd);
    }
    meow::messagePrintf(-1, "ok");
    meow::messagePrintf(1, "build");
    t0 = std::clock();
    tree.build();
    meow::messagePrintf(-1, "ok, %.3f seconds", (std::clock() - t0) * 1.0 / CLOCKS_PER_SEC);
    meow::messagePrintf(1, "query...");
    v.resize(D);
    meow::KD_Tree<Node, double>::Vectors ret;
    for(int k = 1; k <= std::min(100, N); k++){
        meow::messagePrintf(1, "range k = %d", k);
        t1 = t2 = 0;
        for(int i = 0; i < 10; i++){
            Node nd;
            nd.v.resize(D);
            for(int d = 0; d < D; d++){
                v[d] = 12345.0 * (1.0 * rand() / RAND_MAX - 0.3);
                nd[d] = v[d];
            }
            t0 = std::clock();
            // NOTE(review): presumably cheap when nothing was inserted since the
            // previous build, but it is counted in the KD timing — confirm in KD_Tree.h.
            tree.build();
            ret = tree.query(nd, k, true);
            t1 += std::clock() - t0;
            t0 = std::clock();
            find(v, k);
            t2 += std::clock() - t0;
            if((int)ret.size() != (int)std::min(k, N)){
                meow::messagePrintf(-1, "(%d)query fail, size error", i);
                meow::messagePrintf(-1, "fail");
                return false;
            }
            for(int kk = 0; kk < k; kk++){
                if(order[kk] != ret[kk].id){
                    //show(v, ret, order, k);
                    meow::messagePrintf(-1, "(%d)query fail", i);
                    meow::messagePrintf(-1, "fail");
                    return false;
                }
            }
        }
        meow::messagePrintf(-1, "ok %.3f/%.3f",
                            t1 * 1.0 / CLOCKS_PER_SEC,
                            t2 * 1.0 / CLOCKS_PER_SEC
                           );
    }
    meow::messagePrintf(-1, "ok");
    return true;
}