C++

セグメント木ライブラリを整備する(3)

ISO_C++_Logo

2回に分けて、セグメント木を使うプログラムを紹介しました。今回は、セグメント木クラスを汎用的に使えるようにします。

セグメント木

1点更新と範囲の最小値を求めるプログラム1点更新と範囲の和を求めるプログラムを紹介しました。

そこで紹介したセグメント木のクラスは、可搬性がなくそれぞれの問題を解くために固定していました。コードの差分を比較すると、以下しか差分がありません。

  • 演算の関数:最小値を求める場合はminを適用しました。和を求める場合は、演算子+で加えました。
  • 初期値:最小値を求める場合は非常に大きな値を設定しました。和を求める場合は0を設定しました。

一般的に以下の汎用化ができます。

  • セグメント木は、以下の2つを与えれば、1点更新と区間の要素に演算を行った結果を得ることができる。
    • 演算をする関数 $op$:ただし、結合律 $(a\; op\; b)\; op\; c = a\; op\; (b\; op\; c)$ を満たす必要がある。
    • 単位元:任意の $a$ に対して、$a\; op\; e = a$ となるような単位元 $e$

上記の $op$ と $e$ をコンストラクタで与えるようにしたコードは以下となります。なお、値が取得できるメソッド get も加えました。

class SegmentTree {
private:
	vector<int> data;
	int (*op)(int, int);
	int e;

public:
	int size;

	SegmentTree(int n, 	int (*_op)(int, int), int _e) {
		op = _op;
		e = _e;
		size = 1;
		while (size < n) {
			size *= 2;
		}
		data.assign(2 * size - 1, e);
	}

	int get(int k) {
		return data[k + size - 1];
	}

	void update(int k, int a) {
		k += size - 1;
		data[k] = a;
		while (k > 0) {
			k = (k - 1) / 2;
			data[k] = op(data[k * 2 + 1], data[k * 2 + 2]);
		}
	}

	int query(int a, int b) {
		return query_sub(a, b, 0, 0, size);
	}

	// call ST.query(a, b, 0, 0, size);
	int query_sub(int a, int b, int k, int l, int r) {
		if ((r <= a)||(b <= l)) {
			return e;
		}
		if ((a <= l)&&(r <= b)) {
			return data[k];
		} else {
			int vl = query_sub(a, b, k * 2 + 1, l, (l + r) / 2);
			int vr = query_sub(a, b, k * 2 + 2, (l + r) / 2, r);
			return op(vl, vr);
		}
	}
};

ライブラリのテスト

いままでに紹介した DSL_2A問題DSL_2B問題をこのライブラリを使って解いてみます。

どちらのプログラムも、クラス SegmentTree は同一です。

DSL_2A問題を解くプログラム

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;
const int INF = (1U << 31) - 1;

class SegmentTree {
private:
	vector<int> data;
	int (*op)(int, int);
	int e;

public:
	int size;

	SegmentTree(int n, 	int (*_op)(int, int), int _e) {
		op = _op;
		e = _e;
		size = 1;
		while (size < n) {
			size *= 2;
		}
		data.assign(2 * size - 1, e);
	}

	int get(int k) {
		return data[k + size - 1];
	}

	void update(int k, int a) {
		k += size - 1;
		data[k] = a;
		while (k > 0) {
			k = (k - 1) / 2;
			data[k] = op(data[k * 2 + 1], data[k * 2 + 2]);
		}
	}

	int query(int a, int b) {
		return query_sub(a, b, 0, 0, size);
	}

	// call ST.query(a, b, 0, 0, size);
	int query_sub(int a, int b, int k, int l, int r) {
		if ((r <= a)||(b <= l)) {
			return e;
		}
		if ((a <= l)&&(r <= b)) {
			return data[k];
		} else {
			int vl = query_sub(a, b, k * 2 + 1, l, (l + r) / 2);
			int vr = query_sub(a, b, k * 2 + 2, (l + r) / 2, r);
			return op(vl, vr);
		}
	}
};

int op(int a, int b)
{
	return min(a, b);
}

int main()
{
	int n, q;
	cin >> n >> q;

	SegmentTree st(n, op, INF);
	for (int i = 0; i < q; ++i) {
		int command;
		cin >> command;
		if (command == 0) {
			int pos, x;
			cin >> pos >> x;
			st.update(pos, x);
		} else if (command == 1) {
			int L, R;
			cin >> L >> R;
			++R;
			cout << st.query(L, R) << endl;
		}
	}

	return 0;
}

DSL_2B問題を解くプログラム

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

class SegmentTree {
private:
	vector<int> data;
	int (*op)(int, int);
	int e;

public:
	int size;

	SegmentTree(int n, 	int (*_op)(int, int), int _e) {
		op = _op;
		e = _e;
		size = 1;
		while (size < n) {
			size *= 2;
		}
		data.assign(2 * size - 1, e);
	}

	int get(int k) {
		return data[k + size - 1];
	}

	void update(int k, int a) {
		k += size - 1;
		data[k] = a;
		while (k > 0) {
			k = (k - 1) / 2;
			data[k] = op(data[k * 2 + 1], data[k * 2 + 2]);
		}
	}

	int query(int a, int b) {
		return query_sub(a, b, 0, 0, size);
	}

	// call ST.query(a, b, 0, 0, size);
	int query_sub(int a, int b, int k, int l, int r) {
		if ((r <= a)||(b <= l)) {
			return e;
		}
		if ((a <= l)&&(r <= b)) {
			return data[k];
		} else {
			int vl = query_sub(a, b, k * 2 + 1, l, (l + r) / 2);
			int vr = query_sub(a, b, k * 2 + 2, (l + r) / 2, r);
			return op(vl, vr);
		}
	}
};

int op(int a, int b)
{
	return a + b;
}

int main()
{
	int n, q;
	cin >> n >> q;

	SegmentTree st(n, op, 0);
	for (int i = 0; i < q; ++i) {
		int command;
		cin >> command;
		if (command == 0) {
			int pos, x;
			cin >> pos >> x;
			--pos;
			st.update(pos, st.get(pos) + x);
		} else if (command == 1) {
			int L, R;
			cin >> L >> R;
			--L;
			cout << st.query(L, R) << endl;
		}
	}

	return 0;
}

最後に

この汎用化したライブラリを自分の道具箱に加えることにします。

COMMENT

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA