東京大学プログラミングコンテスト2011 L番目の数字

問題文 解説pdf AOJ

問題(概略)

木の各頂点に数値がついています。以下のクエリに答えなさい:
「頂点(v, w)を結ぶ経路上でL番目に小さい数値を出力せよ」

  • 1 <= サイズ <= 10^5
  • 1 <= クエリ数 <= 10^5
  • 1 <= 数値 <= 10^9

解答

解説pdfを読んで、最後に書いてある「別解」を実装してみた。
自分で考えて実装したので、解説者の解答と違う所もあるかもしれない。
解説pdfにはWaveletTreeとあるけれども、WaveletMatrixでも同じなのでWaveletMatrixベースでやった。
WaveletTreeやWaveletMatrixについては過去の記事やそこに書いてある参照を見ると少し参考になるかもしれない。

解説にもあるように、「各アルファベットが出現するだけじゃなく、減りもする」ことを実装しなければいけない。
これを自分は以下のように実装した。
まず、各位置に対して対象のアルファベットと一緒に"符号ビット"(出現する(正)か減る(負)か)を持つ。これはTreeの各段で全て持った。
そして、「ある範囲のアルファベットの数を数える」ということをするところを全て「rank(b, l, r) - 2 * bの符号のrank(l, r)」というふうに置き換えた。
このようにしたらrankやquantileはうまくいった。今回使うのはquantileなので良い。
ちなみにquantileというのは「区間[l,u)の中で(k+1)番目に(小さい/大きい)値を求める」というクエリ。この問題に適していることが明らかにわかる。
符号付きquantileは、それぞれのアルファベットはあたかも符号付き出現回数のみしか出現していないように振る舞う。負の出現回数のものがある場合に対しては未定義(実装次第では意味を持たせられるかも)。

解説にある次の"拡張"「1つの区間でなく,2つの区間に関するクエリを処理する」。
これは実際には処理するクエリはquantileだけでよい。2つの区間の和でのquantileはどうするか。
これは意外に簡単で、単に各区間ごとに範囲を持ち、カウントをするところでは総和を取ればできてしまう。

さて、そのデータ構造を用いてどのようにこの問題を解くか。
まず、解説にもあるように"符号付き"Euler-Tourをする。これを上記のWaveletMatrixにセットする。
この符号付き列は重要な性質を持っている:
全ての(p,v)という2頂点に対してpがvの親であるとき[+pの位置, +vの位置](+xの位置とは、頂点xの正の(その頂点に入ってくる)出現位置のこと)はすごい:
「p→vの経路上の頂点」がそれぞれ1回出現し、それ以外の頂点は「正の出現をしたあとに負の出現をする」。
つまり符号付きで考えると後者の頂点は打ち消し合って「p→vの経路上の頂点」のみが残る!

あとはLCAで二頂点の親を求めれば([+LCA(v,u)の位置, +vの位置] + (+LCA(v,u)の位置, +uの位置])が「v→uの経路上の頂点」となり、それをquantileで処理すればそれが答え!

コメント

面白かった。こういうものをじっくり考え、実装するのは好き。
これが「別解」であって、説明も少ないことも面白くさせた。
しかしこの問題の別解でない方「永続的データ構造」全くわかってないので、そっちのほうもいつか実装したい。

コード

本番のデータセットをダウンロードして、手動で何個かを通して正解した。コーナーケースで間違える可能性はある。
知らなかったのだけど、AOJに問題があったので提出した。ACした
速さは最適化していない。あとスタックが1MBだとあふれる。

#include <vector>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define rer(i,l,u) for(int (i)=(int)(l);(i)<=(int)(u);++(i))
#define reu(i,l,u) for(int (i)=(int)(l);(i)<(int)(u);++(i))
#define all(o) (o).begin(), (o).end()
#define mset(m,v) memset(m,v,sizeof(m))
using namespace std;

typedef unsigned int u32;
inline int popcount(u32 x) {
	x = x - ((x >> 1) & 0x55555555); 
	x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
	return ((x + (x >> 4) & 0xF0F0F0F) * 0x1010101) >> 24;
}
//使わないselectは除去してある
struct FullyIndexableDictionary {
	int length, blockslength, count;
	vector<u32> blocks; vector<int> ranktable;
	FullyIndexableDictionary(int len): length(len) {
		blocks.resize((blockslength = (len + 31) / 32) + 1);
	}
	inline void set(int i) { blocks[i / 32] |= 1 << i % 32; }
	void build() {
		if(length == 0) { count = 0; return; }
		ranktable.assign(blockslength + 1, 0);
		int count0 = 0, count1 = 0;
		for(int i = 0; i < blockslength; i ++) {
			ranktable[i] = count1;
			count1 += popcount(blocks[i]);
			count0 = 32 * (i + 1) - count1;
		}
		ranktable[blockslength] = count1;
		count = count1;
	}
	inline int rank(int pos) const {	//[0..pos)の1の個数
		int block_idx = pos / 32;
		return ranktable[block_idx] + popcount(blocks[block_idx] & (1U << pos % 32)-1);
	}
	inline int rank(bool b, int pos) const { return b ? rank(pos) : pos - rank(pos); }
	inline int rank(bool b, int left, int right) const { return rank(b, right) - rank(b, left); }
};

//アルファベットが出現だけではなく消えることもある
//出現するか消えるかを「符号」(sign)と呼ぼう。出現が正, 消去が負
struct SignedWaveletMatrix {
	typedef unsigned int Val;
	static const Val UNDEFINED = Val(-1);
	int length, bitsize; Val maxval;
	vector<FullyIndexableDictionary> dicts, sign0dicts, sign1dicts;
	vector<int> mids;

	//signで出現か消えるか(0 or 1)
	void init(const vector<Val>& data, const vector<unsigned char>& signs) {
		length = data.size();
		maxval = *max_element(data.begin(), data.end());
		if(Val(1) << (8 * sizeof(Val) - 1) <= maxval) bitsize = 8 * sizeof(Val);
		else for(bitsize = 0; Val(1) << bitsize <= maxval; bitsize ++) ;
		dicts.assign(bitsize, length);
		sign0dicts.assign(bitsize, length);
		sign1dicts.assign(bitsize, length);
		mids.assign(bitsize, 0);
		vector<Val> datacurrent(data), datanext(length);
		vector<unsigned char> signscurrent(signs), signsnext(signs.size());
		for(int bit = 0; bit < bitsize; bit ++) {
			//dictsは符号無しの普通のもの
			//sign0, sign1はdictsと同じ位置の、それぞれの符号
			int pos = 0;
			for(int i = 0; i < length; i ++)
				if((datacurrent[i] >> (bitsize - bit - 1) & 1) == 0) {
					if(signsnext[pos] = signscurrent[i])
						sign0dicts[bit].set(i);
					datanext[pos ++] = datacurrent[i];
				}
			mids[bit] = pos;
			for(int i = 0; i < length; i ++)
				if((datacurrent[i] >> (bitsize - bit - 1) & 1) != 0) {
					dicts[bit].set(i);
					if(signsnext[pos] = signscurrent[i])
						sign1dicts[bit].set(i);
					datanext[pos ++] = datacurrent[i];
				}
			dicts[bit].build();
			sign0dicts[bit].build();
			sign1dicts[bit].build();
			signscurrent.swap(signsnext);
			datacurrent.swap(datanext);
		}
	}

private:
	inline int signrank(int bit, bool b, int pos) const {
		return (b ? sign1dicts : sign0dicts)[bit].rank(pos);
	}
	inline int bitrank_signed(int bit, bool b, int pos) const {
		return dicts[bit].rank(b, pos) - signrank(bit, b, pos) * 2;
	}
	inline int bitrank_signed(int bit, bool b, int left, int right) const {
		return bitrank_signed(bit, b, right) - bitrank_signed(bit, b, left);
	}
public:
	//符号つきで数える
	int signedcount(int left, int right) const {
		return (right - left) - sign0dicts[0].rank(true, left, right) * 2 - sign1dicts[0].rank(true, left, right) * 2;
	}

	//(k+1)番目に"小さい値"
	//範囲の中で負の回数出現する値がある場合の結果は未定義!
	//区間を複数に対応させるのは案外簡単。countするところで全部足すだけ
	Val quantile_ranges(const vector<int>& lefts0, const vector<int>& rights0, int k) const {
		int n = lefts0.size();
		int width = 0;
		for(int i = 0; i < n; i ++) width += signedcount(lefts0[i], rights0[i]);
		if(width <= k) { return UNDEFINED; }
		static vector<int> lefts, rights;	//自動変数だとメモリ確保の時間…と思ったけどこれだとどうだろう?
		lefts.assign(lefts0.begin(), lefts0.end());
		rights.assign(rights0.begin(), rights0.end());
		Val val = 0;
		for(int bit = 0; bit < bitsize; bit ++) {
			int count = 0;
			for(int i = 0; i < n; i ++) {
				count += bitrank_signed(bit, false, lefts[i], rights[i]);
			}
			bool dir = count <= k;
			val = val << 1 | (dir ? 1 : 0);
			if(dir) k -= count;
			for(int i = 0; i < n; i ++) {
				lefts[i] = dicts[bit].rank(dir, lefts[i]);
				rights[i] = dicts[bit].rank(dir, rights[i]);
				if(dir) lefts[i] += mids[bit], rights[i] += mids[bit];
			}
		}
		return val;
	}
	//符号を考えた他の関数も考えてみる。
	//rank_allは普通に実装できる。実装した。
	//selectは少なくとも少しは難しそう。複数の場所がk番目になりうるし。
	//dfs,bfsなんかも普通にできそうだけど実装はしてない。
};

typedef SignedWaveletMatrix::Val Val;

int N, Q;
Val x[111111];
vector<int> edges[111111];
int in_k[111111];
vector<Val> vals; vector<unsigned char> inout;

//LCAは初めて実装したので悪いところあるかも
int lca_depth[111111];
int lca_doubling[111111][17];	//lca_doubling[i][k]はiの2^k-親
vector<int> lca_tmp;

void lca_dfs(int parent, int i) {
	static int depth = 0;
	lca_depth[i] = depth;
	lca_tmp.push_back(i);
	for(int j = 1, k = 0; j < lca_tmp.size(); j *= 2, k ++) {
		lca_doubling[i][k] = lca_tmp[lca_tmp.size() - 1 - j];
	}
	for(vector<int>::iterator j = edges[i].begin(); j != edges[i].end(); ++ j) if(*j != parent) {
		depth ++;
		lca_dfs(i, *j);
		depth --;
	}
	lca_tmp.pop_back();
}

void lca_init(int root) {
	mset(lca_doubling, -1);
	lca_dfs(-1, 0);
}

int lca_query(int v, int u) {
	if(lca_depth[v] < lca_depth[u])
		swap(v, u);

	int level = 0;
	for(; 1 << level <= lca_depth[v]; level ++) ;

	for(int i = level-1; i >= 0; i --)
		if(lca_depth[v] - (1 << i) >= lca_depth[u])
			v = lca_doubling[v][i];

	if(v == u) return v;

	for(int i = level-1; i >= 0; i --)
		if(lca_doubling[v][i] != -1 && lca_doubling[v][i] != lca_doubling[u][i])
			v = lca_doubling[v][i], u = lca_doubling[u][i];

	return lca_doubling[v][0];
}

void euler_tour(int parent, int i) {
	static int k = 0;
	in_k[i] = k;
	vals.push_back(x[i]);
	inout.push_back(0);
	k ++;

		for(vector<int>::iterator j = edges[i].begin(); j != edges[i].end(); ++ j) if(*j != parent)
		euler_tour(i, *j);

	vals.push_back(x[i]);
	inout.push_back(1);
	k ++;
}

int main() {
	scanf("%d%d", &N, &Q);
	rep(i, N) scanf("%d", &x[i]);
	rep(i, N-1) {
		int a, b;
		scanf("%d%d", &a, &b);
		a --, b --;
		edges[a].push_back(b);
		edges[b].push_back(a);
	}
	lca_init(0);
	
	euler_tour(-1, 0);

	//これでは数値を10^9の範囲のままやってるけど、座標圧縮すれば速くなると思う
	SignedWaveletMatrix wm;
	wm.init(vals, inout);

	vector<int> lefts(2), rights(2);
	rep(i, Q) {
		int v, w, l;
		scanf("%d%d%d", &v, &w, &l);
		v --, w --, l --;
		
		int u = lca_query(v, w);
		
		//区間のそれぞれの値の出現回数は必ず正であることに注意
		//	子供なら上に登って降りてこないことはないので。
		//符号を考えると、考えたいパス以外の子供は正負で打ち消すことができる。
		//[in_x, in_y]は、xがyの親ならばx→yのパスに対応する。
		lefts[0] = in_k[u];
		rights[0] = in_k[v] + 1;
		//(in_x, in_y]はxを含めないバージョン
		lefts[1] = in_k[u] + 1;
		rights[1] = in_k[w] + 1;

		cout << wm.quantile_ranges(lefts, rights, l) << endl;
	}
	return 0;
}