Templates -- Meow  1.2.9
A C++ template library containing various interesting classes and functions
VP_Tree.h
Go to the documentation of this file.
1 #ifndef dsa_VP_Tree_H__
2 #define dsa_VP_Tree_H__
3 
#include "../math/utility.h"

#include <cstddef>
#include <cstdlib>

#include <algorithm>
#include <list>
#include <queue>
#include <stack>
#include <utility>
#include <vector>
12 
13 namespace meow {
14 
/*!
 * @brief Vantage-point tree (VP-tree) for k-nearest-neighbour queries.
 *
 * Stores a set of Vectors and answers "which stored vectors are closest to a
 * query vector" using squared Euclidean distance over the first `dimension`
 * components. Similar in usage to KD_Tree.
 *
 * insert() and erase() only mark the tree dirty. The tree itself is
 * (re)built lazily by build(), by forceBuild(), or by the next query().
 *
 * Template requirements, as exercised by this class:
 *  - Vector: copyable; operator[](size_t) whose element differences convert
 *    to Scalar; operator== (used by erase) and operator< (used by query with
 *    compareWholeVector == true).
 *  - Scalar: constructible from an integer; supports +, -, *, unary -, +=
 *    and operator<.
 */
template<class Vector, class Scalar>
class VP_Tree {
public:
  typedef std::vector<Vector> Vectors;
private:
  // One tree node. vectors_[index_] is the vantage point. threshold_ is the
  // squared distance that separates the near subtree (dist2 to the vantage
  // point <= threshold_) from the far subtree (dist2 >= threshold_).
  // threshold_ is only meaningful when the node has a child.
  struct Node {
    size_t index_;
    Scalar threshold_;
    Node* nearChild_;
    Node* farChild_;

    explicit Node(size_t index):
      index_(index), threshold_(0), nearChild_(NULL), farChild_(NULL) {
    }
  };

  // A candidate result: an index into vectors_ and its squared distance to
  // the query vector.
  struct Answer {
    size_t index_;
    Scalar dist2_;

    Answer(size_t index, Scalar const& dist2): index_(index), dist2_(dist2) {
    }
  };

  // Strict ordering of answers by distance. When cmpValue_ is true, equal
  // distances are broken by comparing the stored vectors with operator<.
  // Otherwise answers at equal distance are treated as equivalent.
  class AnswerCompare {
  private:
    Vectors const* vectors_;
    bool cmpValue_;
  public:
    AnswerCompare(Vectors const* vectors, bool cmpValue):
      vectors_(vectors), cmpValue_(cmpValue) {
    }
    bool operator()(Answer const& a, Answer const& b) const {
      if (a.dist2_ < b.dist2_) return true;
      if (b.dist2_ < a.dist2_) return false;
      return (cmpValue_ && ((*vectors_)[a.index_] < (*vectors_)[b.index_]));
    }
  };
  typedef std::vector<Answer> AnswerV;
  // Max-heap under AnswerCompare: top() is the worst answer kept so far.
  typedef std::priority_queue<Answer, AnswerV, AnswerCompare> Answers;

  Vectors vectors_;   // stored points; build() reorders them in place
  Node* root_;        // owned tree; may be stale while needRebuild_ is set
  size_t dimension_;  // number of leading components used by distance2()
  bool needRebuild_;  // set by insert()/erase(); cleared on rebuild/clear

  // Squared Euclidean distance over the first dimension_ components.
  Scalar distance2(Vector const& v1, Vector const& v2) const {
    Scalar ret(0);
    for (size_t i = 0; i < dimension_; i++) {
      Scalar diff(v1[i] - v2[i]);
      ret += diff * diff;
    }
    return ret;
  }

  // Returns the sign (-1, 0 or 1) of sqrt(a2) + sqrt(b2) - sqrt(c2) without
  // taking square roots. A negative b2 means "subtract sqrt(-b2)" instead of
  // adding. a2, |b2| and c2 are expected to be non-negative squared lengths.
  int distanceCompare(Scalar const& a2, Scalar const& b2,
                      Scalar const& c2) const {
    if (b2 < Scalar(0)) {
      // sqrt(a2) - sqrt(-b2) vs sqrt(c2)  <=>  sqrt(a2) vs sqrt(c2) + sqrt(-b2)
      return -distanceCompare(c2, -b2, a2);
    }
    // (sqrt(a2) + sqrt(b2))^2 = a2 + b2 + 2*sqrt(a2*b2); compare with c2.
    Scalar cab(c2 - a2 - b2);
    if (cab < Scalar(0)) return 1;
    Scalar ab2(Scalar(4) * a2 * b2);
    Scalar cab2(cab * cab);
    if (ab2 < cab2) return -1;
    if (cab2 < ab2) return 1;
    return 0;
  }

  // Quickselect on the squared distances of vectors_[first..last] (inclusive)
  // to `center`. Reorders that range so that the element at first + order has
  // the order-th smallest distance, nothing before it is farther, and nothing
  // after it is nearer. Returns that element's squared distance.
  // Requires first <= last and order <= last - first.
  Scalar split(std::ptrdiff_t first, std::ptrdiff_t last, size_t order,
               Vector const& center) {
    std::ptrdiff_t const first0 = first;
    std::vector<Scalar> dist2(last - first + 1);
    for (std::ptrdiff_t i = first; i <= last; i++) {
      dist2[i - first0] = distance2(vectors_[i], center);
    }
    while (first < last) {
      std::ptrdiff_t pivot = first + rand() % (last - first + 1);
      Scalar threshold(dist2[pivot - first0]);
      // Partition into [first, large_first) with dist2 <= threshold and
      // [large_first, last] with dist2 > threshold.
      std::ptrdiff_t large_first = last + 1;
      for (std::ptrdiff_t i = first; first <= large_first - 1; large_first--) {
        if (threshold < dist2[large_first - 1 - first0]) continue;
        while (i < large_first - 1 && !(threshold < dist2[i - first0])) i++;
        if (i < large_first - 1) {
          std::swap(dist2[large_first - 1 - first0], dist2[i - first0]);
          std::swap(vectors_[large_first - 1], vectors_[i]);
          i++;
        }
        else {
          break;
        }
      }
      if (large_first == last + 1) {
        // Every element is <= threshold, so the pivot is a maximum. Park it
        // at `last` and shrink from the right; this also guarantees progress
        // when many distances are equal.
        std::swap(dist2[pivot - first0], dist2[last - first0]);
        std::swap(vectors_[pivot], vectors_[last]);
        if (static_cast<std::ptrdiff_t>(order) == last - first) {
          first = last;
          break;
        }
        last--;
      }
      else {
        size_t small_count = static_cast<size_t>(large_first - first);
        if (order < small_count) {
          last = large_first - 1;
        }
        else {
          order -= small_count;
          first = large_first;
        }
      }
    }
    return dist2[first - first0];
  }

  // Recursively builds a subtree over vectors_[first..last] (inclusive) and
  // returns it; returns NULL for an empty range. A random element is swapped
  // to `first` and used as the vantage point. The rest are split around the
  // median distance into near [first+1, mid) and far [mid, last] halves.
  Node* build(std::ptrdiff_t first, std::ptrdiff_t last) {
    if (first > last) return NULL;
    Node* ret = new Node(static_cast<size_t>(first));
    if (first < last) {
      std::swap(vectors_[first],
                vectors_[first + rand() % (last - first + 1)]);
      std::ptrdiff_t mid = (first + 1 + last + 1) / 2;
      ret->threshold_ = split(first + 1, last,
                              static_cast<size_t>(mid - (first + 1)),
                              vectors_[first]);
      ret->nearChild_ = build(first + 1, mid - 1);
      ret->farChild_  = build(mid, last);
    }
    return ret;
  }

  // Recursive k-nearest search below `node`. `out` keeps at most k answers
  // as a max-heap. A subtree is skipped when the triangle inequality shows
  // it cannot hold anything better than the current worst kept answer.
  void query(Vector const& vector,
             size_t k,
             AnswerCompare const& cmp,
             Node const* node,
             Answers* out) const {
    // k == 0 must stop here: otherwise out->top() below would be called on
    // an empty heap.
    if (node == NULL || k == 0) return;
    Scalar dist2 = distance2(vector, vectors_[node->index_]);
    Answer my_ans(node->index_, dist2);
    if (out->size() < k || cmp(my_ans, out->top())) {
      out->push(my_ans);
      if (out->size() > k) {
        out->pop();
      }
    }
    if (node->nearChild_ == NULL && node->farChild_ == NULL) return;
    // With d = dist to vantage point, r = current k-th radius, t = threshold:
    // the near subtree can contain a point within r only if d - r <= t.
    if (out->size() < k || distanceCompare(dist2, -out->top().dist2_,
                                           node->threshold_) <= 0) {
      query(vector, k, cmp, node->nearChild_, out);
    }
    // The far subtree can contain a point within r only if d + r >= t.
    if (out->size() < k || distanceCompare(dist2, out->top().dist2_,
                                           node->threshold_) >= 0) {
      query(vector, k, cmp, node->farChild_, out);
    }
  }

  // Deletes the subtree rooted at `root` (NULL is a no-op).
  void clear(Node* root) {
    if (root == NULL) return;
    clear(root->nearChild_);
    clear(root->farChild_);
    delete root;
  }

  // Returns a deep copy of the subtree rooted at `root` (NULL for NULL).
  Node* dup(Node* root) {
    if (root == NULL) return NULL;
    Node* ret = new Node(root->index_);
    ret->threshold_ = root->threshold_;
    ret->nearChild_ = dup(root->nearChild_);
    ret->farChild_  = dup(root->farChild_);
    return ret;
  }
public:
  //! @brief Constructor, with dimension = 1.
  VP_Tree(): vectors_(), root_(NULL), dimension_(1), needRebuild_(false) {
    reset(0);
  }

  //! @brief Copy constructor: copies the stored vectors and the built tree.
  VP_Tree(VP_Tree const& tree2):
    vectors_(tree2.vectors_),
    root_(dup(tree2.root_)),
    dimension_(tree2.dimension_),
    needRebuild_(tree2.needRebuild_) {
  }

  //! @brief Constructor with the given dimension (0 is treated as 1).
  VP_Tree(size_t dimension):
    vectors_(), root_(NULL), dimension_(0), needRebuild_(false) {
    reset(dimension);
  }

  //! @brief Destructor: frees all tree nodes.
  ~VP_Tree() {
    clear(root_);
  }

  //! @brief Copies all data (vectors, tree, dimension) from @p tree2.
  //! Self-copy is a no-op.
  //! @return *this
  VP_Tree& copyFrom(VP_Tree const& tree2) {
    // reset() clears *this; on self-copy that would also wipe tree2.
    if (this == &tree2) return *this;
    reset(tree2.dimension_);
    vectors_ = tree2.vectors_;
    root_ = dup(tree2.root_);
    needRebuild_ = tree2.needRebuild_;
    return *this;
  }

  //! @brief Adds @p vector to the set; the tree is rebuilt lazily.
  void insert(Vector const& vector) {
    vectors_.push_back(vector);
    needRebuild_ = true;
  }

  //! @brief Removes one vector equal to @p vector (via operator==) from the
  //! set; the tree is rebuilt lazily.
  //! @return true if a matching vector was found and removed
  bool erase(Vector const& vector) {
    for (size_t i = 0, n = vectors_.size(); i < n; i++) {
      if (vectors_[i] == vector) {
        if (i != n - 1) std::swap(vectors_[i], vectors_[n - 1]);
        needRebuild_ = true;
        vectors_.pop_back();
        return true;
      }
    }
    return false;
  }

  //! @brief Rebuilds the tree only if insert()/erase() was called since the
  //! last build.
  void build() {
    if (needRebuild_) {
      forceBuild();
    }
  }

  //! @brief Unconditionally rebuilds the tree from the stored vectors.
  void forceBuild() {
    // Free the previous tree first; it used to be leaked on every rebuild.
    clear(root_);
    root_ = NULL;
    root_ = build(0, static_cast<std::ptrdiff_t>(vectors_.size()) - 1);
    needRebuild_ = false;
  }

  //! @brief Finds the stored vectors nearest to @p vector, rebuilding the
  //! tree first if needed.
  //! @param vector             query point
  //! @param nearestNumber      maximum number of results (k); 0 yields none
  //! @param compareWholeVector if true, distance ties are broken by comparing
  //!                           the vectors with operator<; otherwise the
  //!                           order among ties is unspecified
  //! @return up to @p nearestNumber vectors, nearest first
  Vectors query(Vector const& vector,
                size_t nearestNumber,
                bool compareWholeVector) const {
    // The lazy rebuild mutates the object even though query() is const.
    // NOTE(review): undefined behaviour if this VP_Tree object was itself
    // defined const and still needs a rebuild; consider `mutable` members.
    const_cast<VP_Tree*>(this)->build();
    AnswerCompare cmp(&vectors_, compareWholeVector);
    Answers answers(cmp);
    query(vector, nearestNumber, cmp, root_, &answers);
    // The heap pops worst-first; collect indices, then emit them reversed.
    std::vector<size_t> worstFirst;
    worstFirst.reserve(answers.size());
    for ( ; !answers.empty(); answers.pop()) {
      worstFirst.push_back(answers.top().index_);
    }
    Vectors ret;
    ret.reserve(worstFirst.size());
    for (size_t i = worstFirst.size(); i > 0; i--) {
      ret.push_back(vectors_[worstFirst[i - 1]]);
    }
    return ret;
  }

  //! @brief Removes all vectors and frees the tree.
  void clear() {
    clear(root_);
    vectors_.clear();
    root_ = NULL;
    needRebuild_ = false;
  }

  //! @brief Removes all data and sets a new dimension (0 is treated as 1).
  //! @return the dimension actually stored
  size_t reset(size_t dimension) {
    clear();
    dimension_ = std::max(static_cast<size_t>(1), dimension);
    return dimension_;
  }

  //! @brief Same as copyFrom(tree2).
  VP_Tree& operator=(VP_Tree const& tree2) {
    return copyFrom(tree2);
  }
};
334 
335 } // meow
336 
337 #endif // dsa_VP_Tree_H__
VP_Tree()
constructor, with dimension = 1
Definition: VP_Tree.h:212
void clear()
清空所有資料
Definition: VP_Tree.h:313
void insert(Vector const &vector)
將給定的Vector加到set中
Definition: VP_Tree.h:252
VP_Tree & operator=(VP_Tree const &tree2)
same as copyFrom(tree2)
Definition: VP_Tree.h:330
VP_Tree & copyFrom(VP_Tree const &tree2)
複製資料
Definition: VP_Tree.h:241
~VP_Tree()
destructor
Definition: VP_Tree.h:234
std::vector< Vector > Vectors
Definition: VP_Tree.h:53
VP_Tree(VP_Tree const &tree2)
constructor, 複製資料
Definition: VP_Tree.h:217
void build()
檢查至今是否有 insert/erase 被呼叫來決定是否 forceBuild()
Definition: VP_Tree.h:275
vector
Definition: Vector.h:19
跟 KD_Tree 很像喔 (VP-tree / vantage-point tree; similar to KD_Tree)
Definition: VP_Tree.h:51
size_t reset(size_t dimension)
清空所有資料並重新給定維度
Definition: VP_Tree.h:323
VP_Tree(size_t dimension)
constructor, 給定dimension
Definition: VP_Tree.h:225
void forceBuild()
重新建樹
Definition: VP_Tree.h:284
Vectors query(Vector const &vector, size_t nearestNumber, bool compareWholeVector) const
查找
Definition: VP_Tree.h:296
bool erase(Vector const &vector)
將給定的Vector從set移除
Definition: VP_Tree.h:260
T squ(T const &x)
x*x
Definition: utility.h:67