#ifndef dsa_KD_Tree_H__ #define dsa_KD_Tree_H__ #include "../utility.h" #include "../math/utility.h" #include #include #include #include namespace meow { /*! * @brief \c k-dimension tree * * 全名k-dimension tree, 用來維護由\b N個K維度向量所成的集合, * 並可於該set中查找 \b 前i個離給定向量最接近的向量 * * Template Class Operators Request * -------------------------------- * * |const?|Typename|Operator | Parameters |Return Type | Description | * |-----:|:------:|----------:|:-------------|:----------:|:------------------| * |const |Vector |operator[] |(size_t \c n) |Scalar | 取得第 `n` 維度量 | * |const |Vector |operator< |(Vector \c v) |bool | 權重比較 | * |const |Scalar |operator* |(Scalar \c s) |Scalar | 相乘 | * |const |Scalar |operator+ |(Scalar \c s) |Scalar | 相加 | * |const |Scalar |operator- |(Scalar \c s) |Scalar | 相差 | * |const |Scalar |operator< |(Scalar \c s) |bool | 大小比較 | * * @note: * 此資料結構只有在 N >> 2 K 時才比較有優勢, * 當 K 逐漸變大時, 所花時間會跟暴搜沒兩樣 * * @author cat_leopard */ template class KD_Tree { private: struct Node { Vector vector_; ssize_t lChild_; ssize_t rChild_; Node(Vector v, ssize_t l, ssize_t r): vector_(v), lChild_(l), rChild_(r){ } }; typedef std::vector Nodes; class Sorter { private: Nodes const* nodes_; size_t cmp_; public: Sorter(Nodes const* nodes, size_t cmp): nodes_(nodes), cmp_(cmp){ } bool operator()(size_t const& a, size_t const& b) const{ if ((*nodes_)[a].vector_[cmp_] != (*nodes_)[b].vector_[cmp_]) { return ((*nodes_)[a].vector_[cmp_] < (*nodes_)[b].vector_[cmp_]); } return ((*nodes_)[a].vector_ < (*nodes_)[b].vector_); } }; struct Answer { ssize_t index_; Scalar dist2_; // Answer(ssize_t index, Scalar dist2): index_(index), dist2_(dist2) { } Answer(Answer const& answer2): index_(answer2.index_), dist2_(answer2.dist2_) { } }; class AnswerCompare { private: Nodes const* nodes_; bool cmpValue_; public: AnswerCompare(Nodes const* nodes, bool cmpValue): nodes_(nodes), cmpValue_(cmpValue) { } bool operator()(Answer const& a, Answer const& b) const { if (cmpValue_ == true && a.dist2_ == b.dist2_) { return ((*nodes_)[a.index_].vector_ < (*nodes_)[b.index_].vector_); } return (a.dist2_ < b.dist2_); } }; typedef std::vector AnswerV; typedef std::priority_queue Answers; // const ssize_t kNIL_; // Nodes nodes_; size_t root_; bool needRebuild_; size_t dimension_; // Scalar distance2(Vector const& v1, Vector const& v2) const { Scalar ret(0); for(size_t i = 0; i < dimension_; i++){ ret += squ(v1[i] - v2[i]); } return ret; } // void query(Vector const& v, size_t nearestNumber, AnswerCompare const& answerCompare, ssize_t index, int depth, std::vector& dist2Vector, Scalar dist2Minimum, Answers *out) const { if (index == kNIL_) return ; size_t cmp = depth % dimension_; ssize_t this_side, that_side; if (!(nodes_[index].vector_[cmp] < v[cmp])) { this_side = nodes_[index].lChild_; that_side = nodes_[index].rChild_; }else{ this_side = nodes_[index].rChild_; that_side = nodes_[index].lChild_; } query(v, nearestNumber, answerCompare, this_side, depth + 1, dist2Vector, dist2Minimum, out); Answer my_ans(index, distance2(nodes_[index].vector_, v)); if (out->size() < nearestNumber || answerCompare(my_ans, out->top())) { out->push(my_ans); if (out->size() > nearestNumber) out->pop(); } Scalar dist2_old(dist2Vector[cmp]); dist2Vector[cmp] = squ(nodes_[index].vector_[cmp] - v[cmp]); Scalar dist2Minimum2(dist2Minimum + dist2Vector[cmp] - dist2_old); if (out->size() < nearestNumber || !(out->top().dist2_ < dist2Minimum)) { query(v, nearestNumber, answerCompare, that_side, depth + 1, dist2Vector, dist2Minimum2, out); } dist2Vector[cmp] = dist2_old; } ssize_t build(ssize_t beg, ssize_t end, std::vector* orders, int depth) { if (beg > end) return kNIL_; size_t tmp_order = dimension_; size_t which_side = dimension_ + 1; ssize_t mid = (beg + end) / 2; size_t cmp = depth % dimension_; for (ssize_t i = beg; i <= mid; i++) { orders[which_side][orders[cmp][i]] = 0; } for (ssize_t i = mid + 1; i <= end; i++) { orders[which_side][orders[cmp][i]] = 1; } for (size_t i = 0; i < dimension_; i++) { if (i == cmp) continue; size_t left = beg, right = mid + 1; for (int j = beg; j <= end; j++) { size_t ask = orders[i][j]; if(ask == orders[cmp][mid]) { orders[tmp_order][mid] = ask; } else if(orders[which_side][ask] == 1) { orders[tmp_order][right++] = ask; } else { orders[tmp_order][left++] = ask; } } for (int j = beg; j <= end; j++) { orders[i][j] = orders[tmp_order][j]; } } nodes_[orders[cmp][mid]].lChild_ = build(beg, mid - 1, orders, depth + 1); nodes_[orders[cmp][mid]].rChild_ = build(mid + 1, end, orders, depth + 1); return orders[cmp][mid]; } public: //! Custom Type: Vectors is \c std::vector typedef typename std::vector Vectors; //! @brief constructor, with dimension = 1 KD_Tree(): kNIL_(-1), root_(kNIL_), needRebuild_(false), dimension_(1) { } //! @brief constructor, given dimension KD_Tree(size_t dimension): kNIL_(-1), root_(kNIL_), needRebuild_(false), dimension_(dimension) { } //! @brief destructor ~KD_Tree() { } /*! * @brief 將給定的Vector加到set中 */ void insert(Vector const& v) { nodes_.push_back(Node(v, kNIL_, kNIL_)); needRebuild_ = true; } /*! * @brief 將給定的Vector從set移除 */ bool erase(Vector const& v) { for (size_t i = 0, I = nodes_.size(); i < I; i++) { if (nodes_[i] == v) { if (i != I - 1) { std::swap(nodes_[i], nodes_[I - 1]); } needRebuild_ = true; return true; } } return false; } /*! * @brief 檢查至今是否有 insert/erase 被呼叫來決定是否 \c rebuild() */ void build(){ if (needRebuild_) { forceBuild(); } } /*! * @brief 重新建樹 */ void forceBuild() { std::vector *orders = new std::vector[dimension_ + 2]; for (size_t j = 0; j < dimension_ + 2; j++) { orders[j].resize(nodes_.size()); } for (size_t j = 0; j < dimension_; j++) { for (size_t i = 0, I = nodes_.size(); i < I; i++) { orders[j][i] = i; } std::sort(orders[j].begin(), orders[j].end(), Sorter(&nodes_, j)); } root_ = build(0, (ssize_t)nodes_.size() - 1, orders, 0); delete [] orders; needRebuild_ = false; } /*! * @brief 查找 * * 於set中找尋距離指定向量前 \c i 近的向量, 並依照由近而遠的順序排序. * 如果有兩個向量\c v1,v2 距離一樣, 且 \c cmp 為\c true , 則直接依照 * \c v1build(); AnswerCompare answer_compare(&nodes_, compareWholeVector); Answers answer_set(answer_compare); std::vector tmp(dimension_, 0); query(v, nearestNumber, answer_compare, root_, 0, tmp, Scalar(0), &answer_set); Vectors ret(answer_set.size()); for (int i = (ssize_t)answer_set.size() - 1; i >= 0; i--) { ret[i] = nodes_[answer_set.top().index_].vector_; answer_set.pop(); } return ret; } /*! * @brief 清空所有資料 */ void clear() { root_ = kNIL_; nodes_.clear(); needRebuild_ = false; } /*! * @brief 清空所有資料並重新給定維度 */ void reset(size_t dimension) { clear(); dimension_ = dimension; } }; } #endif // dsa_KD_Tree_H__