aboutsummaryrefslogblamecommitdiffstats
path: root/meowpp/dsa/KD_Tree.hpp
blob: ac9f8682a95197b41ce28ccea869d8064b8b200c (plain) (tree)



















































































































































                                                                                 
                                                                           










                                                                        
                       



































                                                                                
#include <cstdlib>
#include <list>
#include <vector>
#include <algorithm>
#include <set>
#include "utility.h"

namespace meow{
  // Tree node: the point's coordinates, its payload, and the indices of its
  // left/right children inside `nodes` (insert() passes NIL for both; the
  // recursive build() fills them in).
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::Node::Node(Keys _key,
                                               Value _value,
                                               ssize_t _l_child,
                                               ssize_t _r_child)
    : key(_key)
    , value(_value)
    , lChild(_l_child)
    , rChild(_r_child)
  {
  }
  //
  // Binds the index comparator to the node array it looks up and to the
  // coordinate (`cmp`) it orders by.
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::Sorter::Sorter(KD_Tree::Nodes const& _nodes,
                                                   size_t _cmp)
    : nodes(_nodes)
    , cmp(_cmp)
  {
  }
  // Orders two node indices by coordinate `cmp` of their keys; when that
  // coordinate is equal the stored values decide.
  template<class Keys, class Key, class Value>
  inline bool
  KD_Tree<Keys, Key, Value>::Sorter::operator()(size_t const& a,
                                                size_t const& b) const{
    return (nodes[a].key[cmp] != nodes[b].key[cmp])
             ? (nodes[a].key[cmp] < nodes[b].key[cmp])
             : (nodes[a].value < nodes[b].value);
  }
  //
  // One search candidate: a node together with its squared distance to the
  // query point (as computed by distance2()).
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::Answer::Answer(Node const& _node,
                                                   Key _dist2)
    : node(_node)
    , dist2(_dist2)
  {
  }
  // Nearest-first ordering of candidates; equal squared distances fall back
  // to comparing the nodes' stored values.
  template<class Keys, class Key, class Value>
  inline bool
  KD_Tree<Keys, Key, Value>::Answer::operator<(Answer const& b) const{
    return (dist2 != b.dist2) ? (dist2 < b.dist2)
                              : (node.value < b.node.value);
  }
  //
  // Sum over the first `dimension` coordinates of squ(k1[i] - k2[i]), i.e.
  // the squared Euclidean distance between k1 and k2 (squ comes from
  // utility.h and presumably squares its argument). Coordinates are
  // accumulated in increasing index order.
  template<class Keys, class Key, class Value>
  inline Key KD_Tree<Keys, Key, Value>::distance2(Keys const& k1,
                                                  Keys const& k2) const{
    Key sum(0);
    size_t axis = 0;
    while(axis < dimension){
      sum += squ(k1[axis] - k2[axis]);
      ++axis;
    }
    return sum;
  }
  // Recursive k-nearest-neighbour search over the subtree rooted at `index`.
  //
  //   key      query point
  //   k        maximum number of answers to keep
  //   index    subtree root (NIL means an empty subtree)
  //   depth    depth of `index`; the splitting coordinate is depth % dimension
  //   dist2_v  per-coordinate squared gaps between the query point and the
  //            splitting planes crossed on the way here (restored on return)
  //   dist2_s  sum of dist2_v; used as a lower bound on the squared distance
  //            from the query point to points of this subtree
  //   ret      receives the best answers sorted ascending by
  //            Answer::operator<; callers pass it empty
  //
  // Returns the number of answers left in *ret (at most k).
  template<class Keys, class Key, class Value>
  inline size_t KD_Tree<Keys, Key, Value>::query(Keys const& key,
                                                 size_t k,
                                                 size_t index,
                                                 int depth,
                                                 std::vector<Key>& dist2_v,
                                                 Key               dist2_s,
                                                 KD_Tree::AnswerList* ret) const{
    // With k == 0 nothing can be kept; returning here also avoids reading
    // *ret->rbegin() of an empty list below.
    if(index == NIL || k == 0){
      return 0;
    }
    size_t const cmp = depth % dimension;
    // Visit first the half of the split that contains the query point.
    ssize_t near_side, far_side;
    if(key[cmp] <= nodes[index].key[cmp]){
      near_side = nodes[index].lChild;
      far_side  = nodes[index].rChild;
    }else{
      near_side = nodes[index].rChild;
      far_side  = nodes[index].lChild;
    }
    size_t size = query(key, k, near_side, depth + 1, dist2_v, dist2_s, ret);
    // Consider this node itself.
    Answer my_ans(nodes[index], distance2(nodes[index].key, key));
    if(size < k || my_ans < *(ret->rbegin())){
      KD_Tree::AnswerListIterator it = ret->begin();
      while(it != ret->end() && !(my_ans < *it)) it++;
      ret->insert(it, my_ans);
      size++;
      // Keep at most k now: the dropped tail could never make the final
      // top k, and the k-th best gives a tighter bound for pruning below.
      if(size > k){
        ret->pop_back();
        size = k;
      }
    }
    // Crossing this node's splitting plane: replace the gap recorded for
    // coordinate cmp and update the running lower bound accordingly.
    Key dist2_old = dist2_v[cmp];
    dist2_v[cmp] = squ(nodes[index].key[cmp] - key[cmp]);
    dist2_s     += dist2_v[cmp] - dist2_old;
    // The far side can only contribute if we still need answers or if its
    // lower bound does not exceed the current k-th best distance.
    if(size < k || (*(ret->rbegin())).dist2 >= dist2_s){
      KD_Tree::AnswerList ret2;
      size += query(key, k, far_side, depth + 1, dist2_v, dist2_s, &ret2);
      // Merge the sorted far-side answers into the sorted *ret.
      KD_Tree::AnswerListIterator it1, it2;
      for(it1 = ret->begin(), it2 = ret2.begin(); it2 != ret2.end(); it2++){
        while(it1 != ret->end() && *it1 < *it2) it1++;
        it1 = ++(ret->insert(it1, *it2));
      }
    }
    while(size > k){
      ret->pop_back();
      size--;
    }
    dist2_v[cmp] = dist2_old;
    return size;
  }
  // Recursively builds the subtree over positions [beg, end] of the
  // presorted index lists and returns the index of its root node (NIL when
  // the range is empty).
  //
  // orders[d] (d < dimension) holds the node indices of the current range
  // sorted by coordinate d; orders[dimension] is scratch space. The median of
  // orders[cmp] becomes the root. Every other orders[i] is stably partitioned
  // so that [beg, mid) holds the left subtree's nodes and (mid, end] the
  // right subtree's nodes, each still sorted by coordinate i.
  template<class Keys, class Key, class Value>
  inline ssize_t KD_Tree<Keys, Key, Value>::build(ssize_t beg,
                                                  ssize_t end,
                                                  std::vector<size_t>* orders,
                                                  int depth){
    if(beg > end){
      return NIL;
    }
    ssize_t mid = (beg + end) / 2;
    // Split on the same coordinate query() uses at this depth. (This used to
    // be depth % 2, which disagreed with query() whenever dimension != 2 and
    // read the scratch list orders[1] when dimension == 1.)
    size_t  cmp    = depth % dimension;
    size_t  median = orders[cmp][mid];
    std::set<size_t> right;
    for(ssize_t i = mid + 1; i <= end; i++){
      right.insert(orders[cmp][i]);
    }
    for(size_t i = 0; i < dimension; i++){
      if(i == cmp) continue;
      ssize_t aa = beg, bb = mid + 1;
      for(ssize_t j = beg; j <= end; j++){
        size_t node = orders[i][j];
        if(node == median){
          orders[dimension][mid] = node;
        }else if(right.find(node) != right.end()){
          orders[dimension][bb++] = node;
        }else{
          orders[dimension][aa++] = node;
        }
      }
      for(ssize_t j = beg; j <= end; j++){
        orders[i][j] = orders[dimension][j];
      }
    }
    nodes[median].lChild = build(beg, mid - 1, orders, depth + 1);
    nodes[median].rChild = build(mid + 1, end, orders, depth + 1);
    return median;
  }
  // Creates an empty one-dimensional tree.
  // root is initialised from the literal -1 (the value NIL gets) rather than
  // from NIL: members are initialised in their declaration order, which is
  // not visible here, so root(NIL) would read an uninitialised NIL if root
  // were declared before NIL.
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::KD_Tree():
  NIL(-1), root(-1), needRebuild(false), dimension(1){ }
  // Creates an empty tree whose points have `_dimension` coordinates.
  // root is initialised from the literal -1 (the value NIL gets) so the
  // result does not depend on whether NIL is declared before root.
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::KD_Tree(size_t _dimension):
  NIL(-1), root(-1), needRebuild(false), dimension(_dimension){ }
  // Nothing is released explicitly; members are destroyed automatically.
  template<class Keys, class Key, class Value>
  inline KD_Tree<Keys, Key, Value>::~KD_Tree()
  {
  }
  // Appends a (point, value) pair. The tree structure is not updated here:
  // the pending flag makes the next build() (which every query calls) redo it.
  template<class Keys, class Key, class Value>
  inline void KD_Tree<Keys, Key, Value>::insert(Keys const& key, Value value){
    nodes.push_back(Node(key, value, NIL, NIL));  // child links set by build()
    needRebuild = true;                           // defer the rebuild
  }
  // Rebuilds the tree from `nodes` if points were inserted since the last
  // build; otherwise does nothing.
  template<class Keys, class Key, class Value>
  inline void KD_Tree<Keys, Key, Value>::build(){
    if(!needRebuild){
      return;
    }
    // orders[d] (d < dimension): every node index, sorted by coordinate d.
    // orders[dimension]: scratch buffer for the recursive partitioning.
    // Owned by a vector so the storage is released even if std::sort or the
    // recursive build throws (the previous new[]/delete[] pair leaked then).
    std::vector<std::vector<size_t> > orders(dimension + 1,
                                             std::vector<size_t>(nodes.size()));
    for(size_t d = 0; d < dimension; d++){
      for(size_t i = 0, n = nodes.size(); i < n; i++){
        orders[d][i] = i;
      }
      std::sort(orders[d].begin(), orders[d].end(), Sorter(nodes, d));
    }
    root = build(0, static_cast<ssize_t>(nodes.size()) - 1, &orders[0], 0);
    needRebuild = false;
  }
  // Returns the value stored at the k-th nearest point to `point` (k = 1 is
  // the nearest); candidates are ranked by Answer::operator<. If k exceeds
  // the number of points, the value of the last-ranked (farthest) point is
  // returned.
  //
  // Precondition (unchecked): the tree holds at least one point and k >= 1.
  // Otherwise no answer is collected and dereferencing ret.rbegin() below is
  // undefined behaviour.
  template<class Keys, class Key, class Value>
  inline Value KD_Tree<Keys, Key, Value>::query(Keys const& point, int k) const{
    // Queries are logically const but may first have to (re)build the tree.
    const_cast<KD_Tree*>(this)->build();
    KD_Tree::AnswerList ret;
    std::vector<Key> tmp(dimension, Key(0));
    query(point, k, root, 0, tmp, Key(0), &ret);
    return (*(ret.rbegin())).node.value;
  }
  // Returns the values of the (up to) k points nearest to `point`, nearest
  // first (ranked by Answer::operator<). An empty tree or k == 0 yields an
  // empty result.
  template<class Keys, class Key, class Value>
  inline typename KD_Tree<Keys, Key, Value>::Values
  KD_Tree<Keys, Key, Value>::rangeQuery(Keys const& point, int k) const{
    // Queries are logically const but may first have to (re)build the tree.
    const_cast<KD_Tree*>(this)->build();
    if(k == 0){
      // The recursive search would otherwise read rbegin() of an empty list.
      return KD_Tree::Values();
    }
    KD_Tree::AnswerList ret;
    std::vector<Key> tmp(dimension, Key(0));
    query(point, k, root, 0, tmp, Key(0), &ret);
    KD_Tree::Values ret_val(ret.size());
    size_t i = 0;
    for(KD_Tree::AnswerListIterator it = ret.begin(); it != ret.end(); ++it){
      ret_val[i++] = (*it).node.value;
    }
    return ret_val;
  }
  // Removes every stored point and returns the tree to its empty state; the
  // dimension is left as it is.
  template<class Keys, class Key, class Value>
  inline void KD_Tree<Keys, Key, Value>::clear(){
    nodes.clear();
    root        = NIL;
    needRebuild = false;
  }
  // Empties the tree, then switches it to `_dimension` coordinates per point.
  template<class Keys, class Key, class Value>
  inline void KD_Tree<Keys, Key, Value>::reset(size_t _dimension){
    this->clear();
    this->dimension = _dimension;
  }
}