#ifndef dsa_VP_Tree_H__ #define dsa_VP_Tree_H__ #include "../math/utility.h" #include #include #include #include #include namespace meow { /*! * @brief 跟KD_Tree很像歐 * * \c VP_Tree 用來維護由 \b N個K維度向量所成的集合 , * 並可於該set中查找 \b 前i個離給定向量最接近的向量* . * 不像 \c KD_Tree 二分樹每次都選擇一個維度去分, 分成小的跟大的, * \c VP_Tree 每次選一個點, 將資料分成 離這個點近的, 跟離這個點遠的. * 至於怎麼選呢...., 嘛還沒研究, 先random * * 參考資料連結: * - http://stevehanov.ca/blog/index.php?id=130 * - http://pnylab.com/pny/papers/vptree/vptree * * Template Class Operators Request * -------------------------------- * * |const?|Typename|Operator | Parameters |Return Type | Description | * |-----:|:------:|----------:|:-------------|:----------:|:------------------| * |const | Vector|operator[] |(size_t \c n) | Scalar | 取得第\c n 維度量 | * |const | Vector|operator= |(Vector \c v) | Vector& | copy operator | * |const | Vector|operator< |(Vector \c v) | bool | 權重比較 | * |const | Scalar| 'Scalar' |(int \c n) | Scalar | 建構子, * 其中一定\c n=0or4 | * |const | Scalar|operator* |(Scalar \c s) | Scalar | 相乘 | * |const | Scalar|operator+ |(Scalar \c s) | Scalar | 相加 | * |const | Scalar|operator- |(Scalar \c s) | Scalar | 相差 | * |const | Scalar|operator- |( ) | Scalar | 取負號 | * |const | Scalar|operator< |(Scalar \c s) | bool | 大小比較 | * * @note: * -實測結果發覺, 維度小的時候, 比起中規中矩的 \c KD_Tree, \c VP_Tree 有 * \b random 於其中, 因此時間複雜度只是期望值 \c O(logN) 但是測資大到 * 一定程度, \c KD_Tree 效率會一整個大幅掉下, 但 \c VP_Tree 幾乎不受影響 * -TODO \c insert(), \c erase() 算是未完成功能 */ template class VP_Tree { public: typedef std::vector Vectors; private: struct Node { size_t index_; Scalar threshold_; Node* nearChild_; Node* farChild_; // Node(size_t index): index_(index), nearChild_(NULL), farChild_(NULL){ } }; struct Answer { size_t index_; Scalar dist2_; // Answer(size_t index, Scalar const& dist2): index_(index), dist2_(dist2){ } Answer(Answer const& answer2): index_(answer2.index_), dist2_(answer2.dist2_){ } }; class AnswerCompare { private: Vectors const* vectors_; bool cmpValue_; public: AnswerCompare(Vectors const* vectors, bool cmpValue): vectors_(vectors), cmpValue_(cmpValue){ } bool operator()(Answer const& a, Answer const& b) const { if (a.dist2_ < b.dist2_) return true; if (b.dist2_ < a.dist2_) return false; return (cmpValue_ && ((*vectors_)[a.index_] < (*vectors_)[b.index_])); } }; typedef std::vector AnswerV; typedef std::priority_queue Answers; Vectors vectors_; Node* root_; size_t dimension_; bool needRebuild_; Scalar distance2(Vector const& v1, Vector const& v2) const { Scalar ret(0); for (size_t i = 0; i < dimension_; i++) ret += squ(v1[i] - v2[i]); return ret; } int distanceCompare(Scalar const& a2, Scalar const& b2, Scalar const& c2) const { if (b2 < 0) { return -distanceCompare(c2, -b2, a2); } Scalar cab(c2 - a2 - b2); if (cab < Scalar(0)) return 1; Scalar ab2(Scalar(4) * a2 * b2), cab2(squ(cab)); if ( ab2 < cab2) return -1; else if (cab2 < ab2) return 1; else return 0; } Scalar split(ssize_t first, ssize_t last, size_t order, Vector const& center) { ssize_t first0 = first; std::vector dist2(last - first + 1); for (ssize_t i = first; i <= last; i++) { dist2[i - first0] = distance2(vectors_[i], center); } while (first < last) { size_t thresholdindex_ = first + rand() % (last - first + 1); Scalar threshold(dist2[thresholdindex_ - first0]); size_t large_first = last + 1; for( ssize_t i=first; first<=(ssize_t)large_first-1; large_first--) { if (threshold < dist2[large_first - 1 - first0]) continue; while (i < (ssize_t)large_first-1&&!(threshold < dist2[i-first0])) i++; if (i < (ssize_t)large_first - 1){ std::swap(dist2 [large_first - 1 - first0], dist2 [i - first0]); std::swap(vectors_[large_first - 1 ], vectors_[i ]); i++; } else { break; } } if (large_first == (size_t)last + 1) { std::swap(dist2 [thresholdindex_-first0], dist2 [last-first0]); std::swap(vectors_[thresholdindex_ ], vectors_[last ]); if ((ssize_t)order == last - first) { first = last; break; } last--; } else { if (order < large_first - first) { last = large_first - 1; } else { order -= large_first - first; first = large_first; } } } return dist2[first - first0]; } // Node* build(ssize_t first, ssize_t last) { if (first > last) return NULL; Node* ret = new Node(first); if (first < last) { std::swap(vectors_[first], vectors_[first + rand() % (last - first + 1)]); ssize_t mid = (first + 1 + last + 1) / 2; ret->threshold_ = split(first + 1, last, mid - (first + 1), vectors_[first]); ret->nearChild_ = build(first + 1, mid - 1 ); ret->farChild_ = build( mid , last); } return ret; } void query(Vector const& vector, size_t k, AnswerCompare const& cmp, Node const* node, Answers* out) const { if (node == NULL) return ; Scalar dist2 = distance2(vector, vectors_[node->index_]); Answer my_ans(node->index_, dist2); if (out->size() < k || cmp(my_ans, out->top())) { out->push(my_ans); if (out->size() > k) { out->pop(); } } if (node->nearChild_ == NULL && node->farChild_ == NULL) return ; if (out->size() < k || distanceCompare(dist2, -out->top().dist2_, node->threshold_) <= 0) { query(vector, k, cmp, node->nearChild_, out); } if (out->size() < k || distanceCompare(dist2, out->top().dist2_, node->threshold_) >= 0) { query(vector, k, cmp, node->farChild_, out); } } void clear(Node* root) { if(root == NULL) return ; clear(root->nearChild_); clear(root->farChild_); delete root; } Node* dup(Node* root) { if(root == NULL) return ; Node* ret = new Node(root->index_); ret->threshold_ = root->threshold_; ret->nearChild_ = dup(root->nearChild_); ret->farChild_ = dup(root->farChild_ ); return ret; } public: //! @brief constructor, with dimension = 1 VP_Tree(): root_(NULL), vectors_(0), dimension_(1), needRebuild_(false){ reset(0); } //! @brief constructor, 複製資料 VP_Tree(VP_Tree const& tree2): vectors_(tree2.vectors_), root_(dup(tree2.root_)), dimension_(tree2.dimension_), needRebuild_(tree2.needRebuild_) { } //! @brief constructor, 給定dimension VP_Tree(size_t dimension): vectors_(0), root_(NULL), dimension_(0), needRebuild_(false) { reset(dimension); } //! @brief destructor ~VP_Tree() { clear(root_); } /*! * @brief 複製資料 */ VP_Tree& copyFrom(VP_Tree const& tree2) { reset(tree2.dimension_); vectors_ = tree2.vectors_; root_ = dup(tree2.root_); needRebuild_ = tree2.needRebuild_; return *this; } /*! * @brief 將給定的Vector加到set中 */ void insert(Vector const& vector) { vectors_.push_back(vector); needRebuild_ = true; } /*! * @brief 將給定的Vector從set移除 */ bool erase (Vector const& vector) { for (ssize_t i = 0, I = vectors_.size(); i < I; i++) { if (vectors_[i] == vector) { if (i != I - 1) std::swap(vectors_[i], vectors_[I - 1]); needRebuild_ = true; vectors_.pop_back(); return true; } } return false; } /*! * @brief 檢查至今是否有 insert/erase 被呼叫來決定是否 \c rebuild() */ void build() { if (needRebuild_) { forceBuild(); } } /*! * @brief 重新建樹 */ void forceBuild() { root_ = build(0, (size_t)vectors_.size() - 1); needRebuild_ = false; } /*! * @brief 查找 * * 於set中找尋距離指定向量前 \c i 近的向量, 並依照由近而遠的順序排序. * 如果有兩個向量\c v1,v2 距離一樣, 且 \c cmp 為\c true , 則直接依照 * \c v1build(); AnswerCompare cmp(&vectors_, compareWholeVector); Answers answers(cmp); query(vector, nearestNumber, cmp, root_, &answers); std::stack rev; for ( ; !answers.empty(); answers.pop()) rev.push(answers.top()); Vectors ret; for ( ; !rev.empty(); rev.pop()) ret.push_back(vectors_[rev.top().index_]); return ret; } /*! * @brief 清空所有資料 */ void clear() { clear(root_); vectors_.clear(); root_ = NULL; needRebuild_ = false; } /*! * @brief 清空所有資料並重新給定維度 */ size_t reset(size_t dimension) { clear(); dimension_ = std::max((size_t)1, dimension); return dimension_; } //! @brief same as \c copyFrom(tree2) VP_Tree& operator=(VP_Tree const& tree2) { return copyFrom(tree2); } }; } #endif // dsa_VP_Tree_H__