aboutsummaryrefslogtreecommitdiffstats
path: root/meowpp/dsa/KD_Tree.hpp
blob: 9e9a92547e55294e642eb91aa36409c1e5362ff1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
#include <cstdlib>
#include <list>
#include <vector>
#include <algorithm>
#include <set>
#include "utility.h"

namespace meow{
  // Node constructor: stores the coordinate vector, the payload and the
  // indices of both children exactly as given.
  //   _key     : coordinates kept in the node
  //   _value   : payload kept in the node
  //   _l_child : index recorded as the left child
  //   _r_child : index recorded as the right child
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::Node::Node(Keys    _key,
                                               Value   _value,
                                               ssize_t _l_child,
                                               ssize_t _r_child)
    : key(_key)
    , value(_value)
    , lChild(_l_child)
    , rChild(_r_child)
  {
  }
  //
  // Sorter constructor: remembers the node container to look indices up in
  // and the coordinate (_cmp) that operator() compares on.
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::Sorter::Sorter(
      KD_Tree::Nodes const& _nodes,
      size_t                _cmp)
    : nodes(_nodes)
    , cmp(_cmp)
  {
  }
  // Strict ordering of two node indices: the node whose coordinate `cmp` is
  // smaller comes first; when that coordinate does not differ, the nodes'
  // values decide.
  template<class Keys, class Key, class Value>
  inline bool
  KD_Tree<Keys, Key, Value>::Sorter::operator()(size_t const& a,
                                                size_t const& b) const{
    // Tie on the selected coordinate: fall back to the stored values.
    if(!(nodes[a].key[cmp] != nodes[b].key[cmp])){
      return (nodes[a].value < nodes[b].value);
    }
    return (nodes[a].key[cmp] < nodes[b].key[cmp]);
  }
  //
  // Answer constructor: pairs a node with its squared distance to the
  // point being queried.
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::Answer::Answer(Node const& _node,
                                                   Key         _dist2)
    : node(_node)
    , dist2(_dist2)
  {
  }
  // Orders answers by squared distance (smaller first); answers whose
  // distances do not differ are ordered by the value of their node.
  template<class Keys, class Key, class Value>
  inline bool
  KD_Tree<Keys, Key, Value>::Answer::operator<(Answer const& b) const{
    if(!(dist2 != b.dist2)){
      // Same distance: the node values break the tie.
      return (node.value < b.node.value);
    }
    return (dist2 < b.dist2);
  }
  //
  // Sum over the first `dimension` coordinates of squ(k1[i] - k2[i]),
  // accumulated in a Key that starts at Key(0).
  template<class Keys, class Key, class Value>
  inline Key KD_Tree<Keys, Key, Value>::distance2(Keys const& k1,
                                                  Keys const& k2) const{
    Key total(0);
    size_t axis = 0;
    while(axis < dimension){
      total += squ(k1[axis] - k2[axis]);
      axis++;
    }
    return total;
  }
  // Recursive k-nearest search below node `index`.
  //
  //   key     : the point being searched for
  //   k       : maximum number of answers to keep
  //   index   : node to visit (NIL ends the recursion)
  //   depth   : depth of `index`; depth % dimension is the axis compared here
  //   dist2_v : per-axis squared offsets accumulated along the current path;
  //             modified during the call and restored before returning
  //   dist2_s : sum of dist2_v, used as a lower bound on the squared
  //             distance to anything in the subtree about to be visited
  //   ret     : output list, kept sorted ascending by Answer::operator<.
  //             It must be empty on entry: the returned count is used as
  //             the size of *ret (both callers in this file pass a fresh
  //             list at the top of a recursion).
  //
  // Returns the number of answers left in *ret (never more than k).
  template<class Keys, class Key, class Value>
  inline size_t KD_Tree<Keys, Key, Value>::query(Keys const& key,
                                                 size_t k,
                                                 size_t index,
                                                 int depth,
                                                 std::vector<Key>& dist2_v,
                                                 Key               dist2_s,
                                                 KD_Tree::AnswerList* ret) const{
    // k == 0 must stop here: below, "size < k" would be false with an empty
    // list and the worst answer *(ret->rbegin()) would be read from nothing.
    if(index == NIL || k == 0){
      return 0;
    }
    size_t cmp = depth % dimension;
    // Visit first the child on the same side of the splitting plane as key.
    ssize_t near_side, far_side;
    if(key[cmp] <= nodes[index].key[cmp]){
      near_side = nodes[index].lChild;
      far_side  = nodes[index].rChild;
    }else{
      near_side = nodes[index].rChild;
      far_side  = nodes[index].lChild;
    }
    size_t size = query(key, k, near_side, depth + 1, dist2_v, dist2_s, ret);
    // Consider this node itself: keep it if there is still room, or if it
    // beats the current worst answer (the last element of the sorted list).
    Answer my_ans(nodes[index], distance2(nodes[index].key, key));
    if(size < k || my_ans < *(ret->rbegin())){
      KD_Tree::AnswerListIterator it = ret->begin();
      while(it != ret->end() && !(my_ans < *it)) it++;
      ret->insert(it, my_ans);
      size++;
    }
    // Crossing the splitting plane costs at least the squared offset on this
    // axis; fold it into the running lower bound for the far subtree.
    Key dist2_old = dist2_v[cmp];
    dist2_v[cmp] = squ(nodes[index].key[cmp] - key[cmp]);
    dist2_s     += dist2_v[cmp] - dist2_old;
    // The far subtree is searched only if it could still contribute.
    if(size < k || (*(ret->rbegin())).dist2 >= dist2_s){
      KD_Tree::AnswerList ret2;
      size += query(key, k, far_side, depth + 1, dist2_v, dist2_s, &ret2);
      // Merge the (sorted) far-side answers into the sorted result list.
      KD_Tree::AnswerListIterator it1, it2;
      for(it1 = ret->begin(), it2 = ret2.begin(); it2 != ret2.end(); it2++){
        while(it1 != ret->end() && *it1 < *it2) it1++;
        it1 = ++(ret->insert(it1, *it2));
      }
    }
    // Drop the worst answers so that at most k remain.
    if(size > k){
      for(size_t extra = size - k; extra > 0; --extra){
        ret->pop_back();
      }
      size = k;
    }
    dist2_v[cmp] = dist2_old;
    return size;
  }
  // Recursively builds the subtree that covers positions [beg, end].
  //
  //   beg, end : inclusive range of positions handled by this call
  //   orders   : array of dimension + 1 index vectors. orders[i] for
  //              i < dimension lists node indices sorted on axis i;
  //              within [beg, end] every one of them holds the same set of
  //              indices. orders[dimension] is scratch space.
  //   depth    : depth of the node being created
  //
  // Returns the index of the subtree's root node, or NIL for an empty range.
  template<class Keys, class Key, class Value>
  inline ssize_t KD_Tree<Keys, Key, Value>::build(ssize_t beg,
                                                  ssize_t end,
                                                  std::vector<size_t>* orders,
                                                  int depth){
    if(beg > end){
      return NIL;
    }
    ssize_t const mid = (beg + end) / 2;
    // The splitting axis has to cycle through every dimension, exactly like
    // the "depth % dimension" used by query(); a fixed "% 2" disagrees with
    // the search for dimension > 2 and selects the scratch vector when
    // dimension == 1.
    size_t const cmp = depth % dimension;
    // The median on the splitting axis becomes the root of this subtree.
    // Recursive calls only rewrite positions other than `mid`, so it is
    // safe to read it once.
    size_t const pivot = orders[cmp][mid];
    // Indices that end up in the right subtree.
    std::set<size_t> right;
    for(ssize_t i = mid + 1; i <= end; i++){
      right.insert(orders[cmp][i]);
    }
    // Re-arrange every other axis' ordering into [left | pivot | right],
    // keeping the relative order inside each part.
    std::vector<size_t>& scratch = orders[dimension];
    for(size_t i = 0; i < dimension; i++){
      if(i == cmp) continue;
      size_t aa = beg, bb = mid + 1;
      for(ssize_t j = beg; j <= end; j++){
        size_t const id = orders[i][j];
        if(id == pivot){
          scratch[mid] = id;
        }else if(right.find(id) != right.end()){
          scratch[bb++] = id;
        }else{
          scratch[aa++] = id;
        }
      }
      for(ssize_t j = beg; j <= end; j++){
        orders[i][j] = scratch[j];
      }
    }
    nodes[pivot].lChild = build(beg, mid - 1, orders, depth + 1);
    nodes[pivot].rChild = build(mid + 1, end, orders, depth + 1);
    return pivot;
  }
  // Default constructor: NIL is -1, the tree has no root (root == NIL),
  // nothing is pending a rebuild, and the dimension is 1.
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::KD_Tree()
    : NIL(-1)
    , root(NIL)
    , needRebuild(false)
    , dimension(1)
  {
  }
  // Constructor with an explicit dimension: NIL is -1, the tree has no
  // root (root == NIL), nothing is pending a rebuild, and the dimension is
  // taken from _dimension.
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::KD_Tree(size_t _dimension)
    : NIL(-1)
    , root(NIL)
    , needRebuild(false)
    , dimension(_dimension)
  {
  }
  // Destructor: nothing to do explicitly here.
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::~KD_Tree()
  {
  }
  // Appends a new childless node (both children NIL) holding key/value and
  // flags the tree so that the next build() re-creates the structure.
  template<class Keys, class Key, class Value>
  inline void
  KD_Tree<Keys, Key, Value>::insert(Keys const& key, Value value)
  {
    this->nodes.push_back(Node(key, value, NIL, NIL));
    this->needRebuild = true;
  }
  // Rebuilds the tree from the stored nodes if an insert() happened since
  // the last build; otherwise does nothing.
  //
  // For every axis the node indices are sorted with Sorter on that axis;
  // one extra vector is allocated as scratch space for the recursive
  // build(beg, end, orders, depth).
  template<class Keys, class Key, class Value>
  inline void KD_Tree<Keys, Key, Value>::build(){
    if(!needRebuild){
      return;
    }
    // A vector of vectors instead of a runtime-sized C array: arrays whose
    // length is only known at run time are not part of ISO C++.
    std::vector<std::vector<size_t> > orders(
      dimension + 1, std::vector<size_t>(nodes.size()));
    for(size_t j = 0; j < dimension; j++){
      for(size_t i = 0, I = nodes.size(); i < I; i++){
        orders[j][i] = i;
      }
      std::sort(orders[j].begin(), orders[j].end(), Sorter(nodes, j));
    }
    // orders always has at least one element (dimension + 1 >= 1), so
    // taking the address of the first one is valid.
    root = build(0, static_cast<ssize_t>(nodes.size()) - 1, &orders[0], 0);
    needRebuild = false;
  }
  // Returns the value of the last entry of the answer list produced by the
  // recursive search for `point` with limit `k`, i.e. the farthest of the
  // answers that were kept.
  //
  // The tree is (re)built first if needed, which is why constness is cast
  // away. The last answer is dereferenced unconditionally, so the search
  // must yield at least one answer (non-empty tree, k >= 1).
  template<class Keys, class Key, class Value>
  inline Value KD_Tree<Keys, Key, Value>::query(Keys const& point, int k) const{
    const_cast<KD_Tree*>(this)->build();
    KD_Tree::AnswerList answers;
    std::vector<Key> axis_dist2(dimension, Key(0));
    query(point, k, root, 0, axis_dist2, Key(0), &answers);
    return answers.rbegin()->node.value;
  }
  // Runs the recursive search for `point` with limit `k` and returns the
  // values of all answers that were kept, in the order of the answer list
  // (ascending by Answer::operator<).
  //
  // The tree is (re)built first if needed, which is why constness is cast
  // away.
  template<class Keys, class Key, class Value>
  inline typename KD_Tree<Keys, Key, Value>::Values
  KD_Tree<Keys, Key, Value>::rangeQuery(Keys const& point, int k) const{
    const_cast<KD_Tree*>(this)->build();
    KD_Tree::AnswerList answers;
    std::vector<Key> axis_dist2(dimension, Key(0));
    query(point, k, root, 0, axis_dist2, Key(0), &answers);
    // Copy the payloads out, one slot per answer.
    KD_Tree::Values values(answers.size());
    size_t slot = 0;
    KD_Tree::AnswerListIterator it = answers.begin();
    while(it != answers.end()){
      values[slot] = it->node.value;
      ++slot;
      ++it;
    }
    return values;
  }
  // Empties the tree: all nodes are removed, the root becomes NIL and no
  // rebuild is pending afterwards. The dimension is left as it is.
  template<class Keys, class Key, class Value>
  inline void KD_Tree<Keys, Key, Value>::clear()
  {
    this->root = NIL;
    this->nodes.clear();
    this->needRebuild = false;
  }
  // Empties the tree via clear() and then switches it to a new dimension.
  template<class Keys, class Key, class Value>
  inline void KD_Tree<Keys, Key, Value>::reset(size_t _dimension)
  {
    this->clear();
    this->dimension = _dimension;
  }
}