aboutsummaryrefslogtreecommitdiffstats
path: root/meowpp/dsa/SegmentTree.h
blob: 305c4c3b8e1adac9520ff0612a29a9c4f636699c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#ifndef   dsa_SegmentTree_H__
#define   dsa_SegmentTree_H__

#include "../math/utility.h"

#include <vector>
#include <algorithm>

#include <cstdlib>

namespace meow {
/*!
 * @brief Segment tree (Chinese name: \c 線段樹)
 *
 * Maintains an array of \c Value over the index range `[0, size-1]` and lets
 * the user run range queries and range modifications (override every element
 * of a range with one value, or add a delta to every element of a range).
 *
 * Template Class Operators Request
 * --------------------------------
 *
 * |const?|Typename|Operator    | Parameters   |Return Type | Description      |
 * |-----:|:------:|-----------:|:-------------|:----------:|:-----------------|
 * |      |Value   |constructor |()            |            | value returned by
 *                                              a query on an invalid range    |
 * |      |Value   |constructor |(int \c 0)    |            | the "zero" value,
 *                                              i.e. the neutral delta         |
 * |const |Value   |operator+   |(Value  \c v) |Value       | addition (shift) |
 * |const |Value   |operator*   |(size_t \c n) |Value       | value of a range
 *                                              of length `n` whose elements
 *                                              all equal this Value           |
 * |const |Value   |operator\|  |(Value  \c v) |Value       | value of two
 *                                              adjacent ranges once merged    |
 *
 * - To maintain a range minimum, i.e. every query asks for the minimum over
 *   `[a, b]`, define
 *      - \c operator+ as 'return the sum'
 *      - \c operator* as 'return *this'
 *      - \c operator| as 'return std::min(*this, v)'
 * - To maintain a range sum, i.e. every query asks for the sum over
 *   `[a, b]`, define
 *      - \c operator+ as 'return the sum'
 *      - \c operator* as 'return (*this) * n'
 *      - \c operator| as 'return the sum'
 *
 * @author cat_leopard
 */
template<class Value>
class SegmentTree {
private:
  //! @brief One tree node; it stands for a contiguous range of indices.
  struct Node {
    //! Merged value of the whole range covered by this node.
    Value    value_;
    //! Delta not yet handed to the children. While \c sameFlag_ is set it is
    //! the value shared by every element of the range instead.
    Value   offset_;
    //! \c true when every element of the range holds the same value.
    bool  sameFlag_;
  };
  //
  size_t             size_;   //!< number of maintained elements, at least 1
  std::vector<Node> nodes_;   //!< implicit binary tree, children of `i` are
                              //!< `2i+1` and `2i+2`
  //
  /*!
   * @brief Applies a modification to a single node covering \c length
   *        elements.
   *
   * @param [in] index  position of the node inside \c nodes_
   * @param [in] length number of elements covered by the node
   * @param [in] value  the new value (assign) or the delta (add)
   * @param [in] assign \c true to overwrite the range, \c false to add
   */
  void applyToNode(size_t index, size_t length, Value const& value,
                   bool assign) {
    Node& node = nodes_[index];
    if (assign) {
      node.value_    = value * length;
      node.offset_   = value;
      node.sameFlag_ = true;
    }
    else {
      node.value_  = node.value_  + value * length;
      node.offset_ = node.offset_ + value;
    }
  }

  /*!
   * @brief Applies a modification to `[l, r]` inside the node that covers
   *        `[L, R]`.
   *
   * The caller guarantees `L <= l <= r <= R`.
   */
  void updateRange(size_t l, size_t r, size_t L, size_t R,
                   size_t index, Value const& value, bool assign) {
    if (l == L && r == R) {
      applyToNode(index, R - L + 1, value, assign);
      return ;
    }
    size_t const mid        = (L + R) / 2;
    size_t const leftChild  = index * 2 + 1;
    size_t const rightChild = index * 2 + 2;
    if (L < R) {
      // Hand the pending state of this node to both children before going
      // deeper, so the children are up to date when they get modified.
      applyToNode(leftChild , mid - L + 1,
                  nodes_[index].offset_, nodes_[index].sameFlag_);
      applyToNode(rightChild, R - mid,
                  nodes_[index].offset_, nodes_[index].sameFlag_);
      nodes_[index].offset_   = Value(0);
      nodes_[index].sameFlag_ = false;
    }
    if (r <= mid) {
      updateRange(l, r, L, mid, leftChild, value, assign);
    }
    else if (mid + 1 <= l) {
      updateRange(l, r, mid + 1, R, rightChild, value, assign);
    }
    else {
      updateRange(l, mid, L, mid, leftChild, value, assign);
      updateRange(mid + 1, r, mid + 1, R, rightChild, value, assign);
    }
    nodes_[index].value_ = (
      (nodes_[leftChild].value_ | nodes_[rightChild].value_)
      + nodes_[index].offset_
    );
  }

  /*!
   * @brief Returns the value of `[l, r]` inside the node that covers
   *        `[L, R]`; the tree is not modified.
   *
   * The caller guarantees `L <= l <= r <= R`.
   */
  Value queryRange(size_t l, size_t r, size_t L, size_t R,
                   size_t index) const {
    Node const& node = nodes_[index];
    if (l == L && r == R) return node.value_;
    // Contribution of this node's pending state to the queried sub-range.
    Value pending = node.offset_ * (r - l + 1);
    if (node.sameFlag_) return pending;
    size_t const mid = (L + R) / 2;
    if (r <= mid) {
      return queryRange(l, r, L, mid, index * 2 + 1) + pending;
    }
    if (mid + 1 <= l) {
      return queryRange(l, r, mid + 1, R, index * 2 + 2) + pending;
    }
    return (  queryRange(l, mid, L, mid, index * 2 + 1)
            | queryRange(mid + 1, r, mid + 1, R, index * 2 + 2)
    ) + pending;
  }
  //
  /*!
   * @brief Validates `[*first, *last]` and clamps it into `[0, size-1]`.
   *
   * @return \c false (range left untouched) if the range is empty or lies
   *         completely outside `[0, size-1]`, \c true otherwise
   */
  bool rangeCorrect(ssize_t* first, ssize_t* last) const {
    ssize_t const maxIndex = static_cast<ssize_t>(size_) - 1;
    if (*last < *first || *last < 0 || maxIndex < *first)
      return false;
    // After the check above `*first <= maxIndex` and `0 <= *last` hold, so
    // only one bound per end is left to clamp.
    *first = std::max(static_cast<ssize_t>(0), *first);
    *last  = std::min(maxIndex, *last);
    return true;
  }
public:
  //! @brief constructor, maintains a single element
  SegmentTree() {
    reset(1);
  }

  //! @brief constructor, with \c size given
  SegmentTree(size_t size) {
    reset(size);
  }

  //! @brief constructor, copies the data of \c tree2
  SegmentTree(SegmentTree const& tree2):
  size_(tree2.size_), nodes_(tree2.nodes_) {
  }

  /*!
   * @brief Copies the data of \c b into this tree
   *
   * @return \c *this
   */
  SegmentTree& copyFrom(SegmentTree const& b) {
    size_  = b.size_;
    nodes_ = b.nodes_;
    return *this;
  }

  /*!
   * @brief Returns the number of maintained elements
   */
  size_t size() const {
    return size_;
  }

  /*!
   * @brief Clears the data and sets the maintained range to \c 0~size-1
   *
   * A \c size of 0 is treated as 1.
   */
  void reset(size_t size){
    size_ = std::max(size, static_cast<size_t>(1));
    // Size the storage from the clamped size_, so the root always exists.
    nodes_.resize(size_ * 4);
    // Marking the root as "all elements equal Value(0)" is enough: stale
    // children get overwritten when the root state is pushed down.
    nodes_[0].sameFlag_ = true;
    nodes_[0].value_  = Value(0);
    nodes_[0].offset_ = Value(0);
  }

  /*!
   * @brief Returns the value of the range \c [first,last] (both ends
   *        included)
   *
   * The range is clamped into `[0, size-1]`; if nothing is left,
   * \c Value() is returned.
   */
  Value query(ssize_t first, ssize_t last) const {
    if (rangeCorrect(&first, &last) == false) return Value();
    return queryRange(static_cast<size_t>(first), static_cast<size_t>(last),
                      0, size_ - 1, 0);
  }

  /*!
   * @brief Sets every element of the range \c [first,last] to \c value
   *
   * The range is clamped into `[0, size-1]`; if nothing is left, the call
   * does nothing.
   */
  void override(ssize_t first, ssize_t last, Value const& value) {
    if (rangeCorrect(&first, &last) == false) return ;
    updateRange(static_cast<size_t>(first), static_cast<size_t>(last),
                0, size_ - 1, 0, value, true);
  }

  /*!
   * @brief Adds \c delta to every element of the range \c [first,last]
   *
   * The range is clamped into `[0, size-1]`; if nothing is left, the call
   * does nothing.
   */
  void offset(ssize_t first, ssize_t last, Value const& delta) {
    if (rangeCorrect(&first, &last) == false) return ;
    updateRange(static_cast<size_t>(first), static_cast<size_t>(last),
                0, size_ - 1, 0, delta, false);
  }

  //! @brief same as copyFrom(b)
  SegmentTree& operator=(SegmentTree const& b) {
    return copyFrom(b);
  }
};

} // meow

#endif // dsa_SegmentTree_H__