1 #ifndef dsa_VP_Tree_H__
2 #define dsa_VP_Tree_H__
4 #include "../math/utility.h"
50 template<
class Vector,
class Scalar>
// Node constructor: records the given index in index_ and starts with both
// child links (nearChild_, farChild_) set to NULL.
61 Node(
size_t index): index_(index), nearChild_(NULL), farChild_(NULL){
// Answer constructor: pairs an index with a squared distance
// (stored in index_ and dist2_ respectively).
68 Answer(
size_t index, Scalar
const& dist2): index_(index), dist2_(dist2){
// Answer copy constructor: copies index_ and dist2_ member-wise from answer2.
70 Answer(Answer
const& answer2):
71 index_(answer2.index_), dist2_(answer2.dist2_){
// AnswerCompare constructor.
//   vectors  - pointer to the vector collection that Answer::index_ refers
//              into; only the pointer is stored (vectors_).
//   cmpValue - when true, operator() breaks distance ties by comparing the
//              referenced vectors themselves.
79 AnswerCompare(
Vectors const* vectors,
bool cmpValue):
80 vectors_(vectors), cmpValue_(cmpValue){
// Ordering predicate on Answers: primarily by squared distance (dist2_).
// When neither distance is less than the other, the result is true only if
// cmpValue_ is set and the vector at a.index_ compares less than the vector
// at b.index_; otherwise the two answers are treated as equivalent (false).
82 bool operator()(Answer
const& a, Answer
const& b)
const {
83 if (a.dist2_ < b.dist2_)
return true;
84 if (b.dist2_ < a.dist2_)
return false;
// Tie on distance: fall back to comparing the vectors, if enabled.
85 return (cmpValue_ && ((*vectors_)[a.index_] < (*vectors_)[b.index_]));
// AnswerV: vector of Answer records, used as the queue's underlying container.
// Answers: priority queue of Answer ordered by AnswerCompare. As with any
// std::priority_queue, top() is the element the comparator orders last,
// i.e. the answer with the largest dist2_ currently held.
88 typedef std::vector<Answer> AnswerV;
89 typedef std::priority_queue<Answer, AnswerV, AnswerCompare> Answers;
// Adds squ(v1[i] - v2[i]) into `ret` for every i in [0, dimension_), i.e.
// works component-wise on the first dimension_ coordinates of v1 and v2.
// NOTE(review): the declaration of `ret` and the return statement are not
// part of this excerpt -- presumably ret starts at Scalar(0) and is returned,
// making this the squared Euclidean distance; confirm against the full file.
// squ() is not defined here; presumably it comes from ../math/utility.h.
96 Scalar distance2(Vector
const& v1, Vector
const& v2)
const {
98 for (
size_t i = 0; i < dimension_; i++) ret +=
squ(v1[i] - v2[i]);
// Three-way comparison carried out on squared quantities, so no square root
// is needed. Reading a2, b2, c2 as a*a, b*b, c*c for non-negative a, b, c
// (assumption -- TODO confirm against callers), the visible branches give the
// sign of (a + b) - c:
//   cab = c2 - a2 - b2 < 0         ->  1  (c*c < a*a + b*b, hence c < a + b)
//   4*a2*b2 < squ(cab)             -> -1
//   squ(cab) < 4*a2*b2             ->  1
// NOTE(review): the condition guarding the recursive call and the final
// return for the remaining (equal) case are not in this excerpt. The
// recursive form -distanceCompare(c2, -b2, a2) suggests that a negative b2
// encodes "a - b versus c" -- confirm against the full source.
101 int distanceCompare(Scalar
const& a2, Scalar
const& b2,
102 Scalar
const& c2)
const {
104 return -distanceCompare(c2, -b2, a2);
106 Scalar cab(c2 - a2 - b2);
107 if (cab < Scalar(0))
return 1;
// With cab >= 0, comparing 2*a*b against cab is done on the squares.
108 Scalar ab2(Scalar(4) * a2 * b2), cab2(
squ(cab));
109 if ( ab2 < cab2)
return -1;
110 else if (cab2 < ab2)
return 1;
// Randomised selection over vectors_[first..last], keyed by the squared
// distance of each vector to `center`.
//   first, last - inclusive index range inside vectors_
//   order       - rank being selected within the range; it is reduced as the
//                 range is narrowed from the left
//   center      - reference vector the distances are measured from
// dist2[] caches distance2(vectors_[i], center) for the range (indexed
// relative to the original `first`, kept in first0). Each pass of the while
// loop picks a random pivot distance (threshold) via rand(), then swaps
// entries whose distance is not greater than the pivot towards the front,
// keeping dist2[] and vectors_ in lock-step. Afterwards the range or `order`
// is adjusted. The value returned is the cached squared distance at the
// final `first` position.
// NOTE(review): several lines of the loop body (embedded listing numbers
// 130-135, 140-145, 148-154) are absent from this excerpt, so the exact
// loop-exit and range-update handling cannot be confirmed from here.
113 Scalar split(ssize_t first, ssize_t last,
size_t order,
114 Vector
const& center) {
115 ssize_t first0 = first;
116 std::vector<Scalar> dist2(last - first + 1);
117 for (ssize_t i = first; i <= last; i++) {
118 dist2[i - first0] = distance2(vectors_[i], center);
120 while (first < last) {
121 size_t thresholdindex_ = first + rand() % (last - first + 1);
122 Scalar threshold(dist2[thresholdindex_ - first0]);
123 size_t large_first = last + 1;
124 for( ssize_t i=first; first<=(ssize_t)large_first-1; large_first--) {
// Entry just below large_first is already greater than the pivot: leave it.
125 if (threshold < dist2[large_first - 1 - first0])
continue;
// Skip front entries that are not greater than the pivot.
126 while (i < (ssize_t)large_first-1&&!(threshold < dist2[i-first0])) i++;
127 if (i < (ssize_t)large_first - 1){
128 std::swap(dist2 [large_first - 1 - first0], dist2 [i - first0]);
129 std::swap(vectors_[large_first - 1 ], vectors_[i ]);
136 if (large_first == (
size_t)last + 1) {
137 std::swap(dist2 [thresholdindex_-first0], dist2 [last-first0]);
138 std::swap(vectors_[thresholdindex_ ], vectors_[last ]);
139 if ((ssize_t)order == last - first) {
146 if (order < large_first - first) {
147 last = large_first - 1;
150 order -= large_first - first;
155 return dist2[first - first0];
// Recursively builds the subtree for vectors_[first..last] (inclusive) and
// yields NULL for an empty range (first > last).
// A randomly chosen element of the range is swapped into position `first`;
// the new node is created with index `first`, so it refers to that element.
// split() is then run on the remaining elements [first + 1, last] with rank
// mid - (first + 1), and its result is stored as the node's threshold_.
// The near child is built from [first + 1, mid - 1], the far child from
// [mid, last].
// NOTE(review): the last argument of the split() call and the return
// statement are not part of this excerpt.
158 Node*
build(ssize_t first, ssize_t last) {
159 if (first > last)
return NULL;
160 Node* ret =
new Node(first);
162 std::swap(vectors_[first],
163 vectors_[first + rand() % (last - first + 1)]);
164 ssize_t mid = (first + 1 + last + 1) / 2;
165 ret->threshold_ = split(first + 1, last, mid - (first + 1),
167 ret->nearChild_ =
build(first + 1, mid - 1 );
168 ret->farChild_ =
build( mid , last);
// One step of the recursive search that collects answers into *out.
// Visible behaviour:
//  - a NULL node is a no-op;
//  - the squared distance from `vector` to the node's vector
//    (vectors_[node->index_]) is computed and wrapped in an Answer;
//  - that answer is considered when *out holds fewer than k entries or when
//    cmp orders it before out->top();
//  - a node without children stops here;
//  - the near child is searched when *out holds fewer than k entries or
//    distanceCompare(dist2, -out->top().dist2_, node->threshold_) <= 0;
//  - the far child is searched when *out holds fewer than k entries or
//    distanceCompare(dist2, out->top().dist2_, node->threshold_) >= 0.
// NOTE(review): the declarations of the parameters `k` and `node`, and the
// statements inside the two ifs that follow my_ans (presumably a push and a
// pop that keep at most k answers), are not part of this excerpt.
172 void query(Vector
const& vector,
174 AnswerCompare
const& cmp,
176 Answers* out)
const {
177 if (node == NULL) return ;
178 Scalar dist2 = distance2(vector, vectors_[node->index_]);
179 Answer my_ans(node->index_, dist2);
180 if (out->size() < k || cmp(my_ans, out->top())) {
182 if (out->size() > k) {
186 if (node->nearChild_ == NULL && node->farChild_ == NULL) return ;
187 if (out->size() < k || distanceCompare(dist2, -out->top().dist2_,
188 node->threshold_) <= 0) {
189 query(vector, k, cmp, node->nearChild_, out);
191 if (out->size() < k || distanceCompare(dist2, out->top().dist2_,
192 node->threshold_) >= 0) {
193 query(vector, k, cmp, node->farChild_, out);
// Walks the subtree rooted at `root`, recursing into the near and far
// children; a NULL root is a no-op.
// NOTE(review): the statement that disposes of `root` itself is not part of
// this excerpt -- presumably the node is deleted after both recursive calls.
196 void clear(Node* root) {
197 if(root == NULL) return ;
198 clear(root->nearChild_);
199 clear(root->farChild_);
// Deep-copies the subtree rooted at `root`: allocates a new Node with the
// same index_, copies threshold_, and duplicates the near and far children
// recursively. An empty subtree (NULL) duplicates to NULL.
// NOTE(review): the statement returning `ret` is not part of this excerpt.
202 Node* dup(Node* root) {
// The function returns Node*, so the empty case must return a value; the
// previous bare `return ;` is ill-formed in a non-void function.
203 if(root == NULL) return NULL;
204 Node* ret =
new Node(root->index_);
205 ret->threshold_ = root->threshold_;
206 ret->nearChild_ = dup(root->nearChild_);
207 ret->farChild_ = dup(root->farChild_ );
// Default constructor: no tree yet (root_ = NULL), vectors_ initialised with
// 0, dimension_ = 1, and no rebuild pending (needRebuild_ = false).
212 VP_Tree(): root_(NULL), vectors_(0), dimension_(1), needRebuild_(false){
218 vectors_(tree2.vectors_),
219 root_(dup(tree2.root_)),
220 dimension_(tree2.dimension_),
221 needRebuild_(tree2.needRebuild_) {
229 needRebuild_(false) {
242 reset(tree2.dimension_);
243 vectors_ = tree2.vectors_;
244 root_ = dup(tree2.root_);
245 needRebuild_ = tree2.needRebuild_;
253 vectors_.push_back(vector);
261 for (ssize_t i = 0, I = vectors_.size(); i < I; i++) {
262 if (vectors_[i] == vector) {
263 if (i != I - 1) std::swap(vectors_[i], vectors_[I - 1]);
285 root_ =
build(0, (
size_t)vectors_.size() - 1);
286 needRebuild_ =
false;
297 size_t nearestNumber,
298 bool compareWholeVector)
const {
300 AnswerCompare cmp(&vectors_, compareWholeVector);
301 Answers answers(cmp);
302 query(vector, nearestNumber, cmp, root_, &answers);
303 std::stack<Answer> rev;
304 for ( ; !answers.empty(); answers.pop()) rev.push(answers.top());
306 for ( ; !rev.empty(); rev.pop()) ret.push_back(vectors_[rev.top().index_]);
317 needRebuild_ =
false;
325 dimension_ = std::max((
size_t)1, dimension);
337 #endif // dsa_VP_Tree_H__
VP_Tree()
constructor, with dimension = 1
void insert(Vector const &vector)
Adds the given Vector to the set
VP_Tree & operator=(VP_Tree const &tree2)
same as copyFrom(tree2)
VP_Tree & copyFrom(VP_Tree const &tree2)
Copies the data
std::vector< Vector > Vectors
VP_Tree(VP_Tree const &tree2)
constructor, copies the data
void build()
Checks whether insert/erase has been called so far to decide whether to rebuild()
size_t reset(size_t dimension)
Clears all data and sets a new dimension
VP_Tree(size_t dimension)
constructor, with the given dimension
Vectors query(Vector const &vector, size_t nearestNumber, bool compareWholeVector) const
Search (lookup)
bool erase(Vector const &vector)
Removes the given Vector from the set