#include "meowpp.h"
#include <vector>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <ctime>
#include <queue>
// Scalar type used for every coordinate and every squared distance below.
typedef long long lnt;
// Test size: N points of dimension D; random coordinates are drawn as
// rand() % MAX, i.e. from [0, MAX).
static int N(100000);
static int D(32);
static int MAX(1000);
// D-dimensional integer point used by the VP_Tree test.
// `v` holds the coordinates; `w` is an integer id that is the only key used by
// operator<, while operator== requires both the coordinates and `w` to match.
struct MyVector{
  std::vector<lnt> v;  // coordinates; the constructors always create D of them
  int w;               // id / ordering key; 0 until the owner assigns it
  // Random point: coordinate i (in index order) is rand() % MAX.
  // `w` is initialized so that operator== / operator< never read an
  // indeterminate value.
  MyVector( ):v(D), w(0){ for(int i = 0; i < D; i++){ v[i] = (lnt)rand() % MAX; } }
  // Diagonal point (k, k, ..., k).
  MyVector(lnt k):v(D, k), w(0){ }
  // Copy construction/assignment are the implicit memberwise ones (Rule of
  // Zero); a hand-written copy constructor would only suppress moves.
  // Unchecked coordinate access.
  lnt & operator[](size_t n) { return v[n]; }
  lnt const& operator[](size_t n) const{ return v[n]; }
  // Orders by id only (VP::Answer uses this to break distance ties).
  bool operator<(MyVector const& v2) const{ return (w < v2.w); }
  // Equal when all D coordinates and the id agree.
  bool operator==(MyVector const& v2) const{
    for(int i = 0; i < D; i++) if(v[i] != v2[i]) return false;
    return (w == v2.w);
  }
};
// Squared Euclidean distance between two D-dimensional points.
// No sqrt is taken, so the result is an exact integer.
static lnt dist2(MyVector const& v1, MyVector const& v2){
  lnt sum = 0;
  int dim = 0;
  while(dim < D){
    lnt const delta = v1[dim] - v2[dim];
    sum += delta * delta;
    ++dim;
  }
  return sum;
}
// Shared point set: filled by the VP_Tree test and read by show(), find() and
// the VP::Answer ordering. Starts empty.
static std::vector<MyVector> data = std::vector<MyVector>();
void show(MyVector const& v, std::vector<MyVector> const& r1, std::vector<MyVector> const& r2){
if(N <= 20 && r1.size() <= 7){
printf("\n");
for(int i = 0; i < N; i++){
printf("%3d) ", data[i].w);
for(int j = 0; j < D; j++)
printf("%8lld ", data[i][j]);
printf(" ===> %lld\n", dist2(data[i], v));
}
printf("\n");
printf("ask) ");
for(int j = 0; j < D; j++)
printf("%8lld ", v[j]);
printf("\n");
printf("---------\n");
for(int i = 0; i < r1.size(); i++){
printf("%3d) ", r1[i].w);
for(int j = 0; j < D; j++)
printf("%8lld ", r1[i][j]);
printf(" ===> %lld\n", dist2(r1[i], v));
}
printf("---------\n");
for(int i = 0; i < r2.size(); i++){
printf("%3d) ", r2[i].w);
for(int j = 0; j < D; j++)
printf("%8lld ", r2[i][j]);
printf(" ===> %lld\n", dist2(r2[i], v));
}
}
}
namespace VP{
  // One candidate in the brute-force k-nearest search: `i` indexes `data`,
  // `d` is its squared distance to the query.
  struct Answer{
    int i;
    lnt d;
    Answer(int _i, lnt _d): i(_i), d(_d){ }
    // No user-declared copy constructor: the implicit one is identical, and
    // declaring it would deprecate the implicit copy-assignment that
    // std::priority_queue relies on and suppress moves.
    // Orders by distance; equal distances fall back to MyVector::operator<
    // on the referenced points (i.e. their ids).
    bool operator<(Answer const& b) const{
      if(d != b.d) return (d < b.d);
      return (data[i] < data[b.i]);
    }
  };
}
// Brute-force k-nearest-neighbour reference used to check the VP tree.
// Returns the min(k, N) vectors of `data` closest to `v`, ordered by
// increasing squared distance (dist2); equal distances are ordered by
// ascending id `w` via VP::Answer::operator<. Returns an empty vector when
// k <= 0.
static std::vector<MyVector> find(MyVector const& v, int k){
  std::vector<MyVector> ret;
  // A negative k would otherwise start the second loop at data[k] (< 0).
  if(k <= 0) return ret;
  int const keep = std::min(k, N);
  // Max-heap of the `keep` best candidates seen so far. Its top is the worst
  // of them, so "push one more, pop one" always discards the farthest.
  std::priority_queue<VP::Answer> qu;
  for(int i = 0; i < keep; i++){
    qu.push(VP::Answer(i, dist2(v, data[i])));
  }
  for(int i = keep; i < N; i++){
    qu.push(VP::Answer(i, dist2(v, data[i])));
    qu.pop();
  }
  // The heap pops farthest-first, so collect the winners and reverse them.
  // Do not pre-size `ret`: MyVector's default constructor calls rand() D
  // times per slot, which would waste time inside the timed brute-force run
  // and shift the random stream used for later queries.
  ret.reserve(qu.size());
  while(!qu.empty()){
    ret.push_back(data[qu.top().i]);
    qu.pop();
  }
  std::reverse(ret.begin(), ret.end());
  return ret;
}
// VP_Tree correctness/timing test.
// Builds a tree over N points (the first N/10 + 1 on the diagonal (i, ..., i),
// the rest random). Then, for every k in [1, min(100, N)], it runs ten random
// queries and checks tree.query(ask, k, true) against the brute-force
// find(ask, k) element by element. Per-k timings (tree / brute force) are
// printed.
TEST(VP_Tree){
  clock_t t0, t1, t2;
  meow::VP_Tree<MyVector, lnt> tree(D);
  meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D);
  data.resize(N);
  for(int i = 0; i < N; i++){
    if(i <= N / 10){
      data[i] = MyVector((lnt)i);
    }else{
      for(int j = 0; j < D; j++){
        data[i][j] = rand() % MAX;
      }
    }
    // Id equal to the index; also the tie-break key for equal distances.
    data[i].w = i;
  }
  for(int i = 0; i < N; i++){
    tree.insert(data[i]);
  }
  meow::messagePrintf(-1, "ok");
  meow::messagePrintf(1, "build");
  t0 = clock();
  tree.build();
  //tree.print();
  meow::messagePrintf(-1, "ok, %.3f seconds", (clock() - t0) * 1.0 / CLOCKS_PER_SEC);
  meow::messagePrintf(1, "query...");
  // Same container type that find() returns and show() prints.
  std::vector<MyVector> ret1, ret2;
  for(int k = 1; k <= std::min(100, N); k++){
    meow::messagePrintf(1, "range k = %d", k);
    t1 = t2 = 0;
    for(int i = 0; i < 10; i++){
      MyVector ask;  // random query point
      // The tree was built above and is never modified afterwards, so t1
      // times the query alone.
      t0 = clock();
      ret1 = tree.query(ask, k, true);
      t1 += clock() - t0;
      t0 = clock();
      ret2 = find(ask, k);
      t2 += clock() - t0;
      // A size mismatch is a failure in its own right. Checking it first also
      // keeps the element loop below from reading past the end of ret2.
      if(ret1.size() != ret2.size()){
        meow::messagePrintf(-1, "(%d)query fail, size error", i);
        meow::messagePrintf(-1, "fail");
        return false;
      }
      for(size_t kk = 0; kk < ret1.size(); kk++){
        if(ret1[kk] == ret2[kk]){
          continue;
        }
        show(ask, ret1, ret2);
        meow::messagePrintf(-1, "(%d)query fail", i);
        meow::messagePrintf(-1, "fail");
        return false;
      }
    }
    meow::messagePrintf(-1, "ok %.3f/%.3f",
                        t1 * 1.0 / CLOCKS_PER_SEC,
                        t2 * 1.0 / CLOCKS_PER_SEC
                        );
  }
  meow::messagePrintf(-1, "ok");
  return true;
};
;