Rustコトハジメ

プログラミング言語Rustと競プロに関する情報をお届けします。

Sparse Tableを実装した

Sparse Tableというのは、配列上のある範囲の最小値(あるいは最大値)を持ったインデックスを返すためのデータ構造(いわゆるRMQ)で、ナイーブに実装すると明らかにO(N)かかるが、Sparse Tableだと構築にO(NlogN)かかるがクエリ自体はO(1)で終わる。

その仕組みはクエリから考えるとわかりやすい。

  • あるインデックスiから連続した2k個の最小値インデックスがtable[k][i]に記憶されているとする
  • レンジ[a,b)が与えられた時、m=2^{floor(log(b-a))}を考える。この時、[a, a+m)と[b-m, b)は、レンジ[a,b)をカバーする。つまり、この中から最小値インデックスを探せばよい。これらはすでに計算済であるから、O(1)で終わる

tableの計算は、例えばtable[2]の情報はtable[1]から計算出来る。これを愚直にDPすればいい。

こういう感じで、2kに関する情報を予め計算してしまい、クエリの高速化に繋げるやり方は他にもLCAなどで見られるし、考えとしてはいかにもアルゴリズムって感じなので、他でも応用が効きそう。

assertが散りばめられてることからわかるように、軽くハマったため、カフェでFワードを連発しながら実装した。

struct SparseTable {
    data: Vec<i64>,
    log_table: Vec<usize>,
    table: Vec<Vec<usize>>,
}

impl SparseTable {
    fn new(data: Vec<i64>) -> Self {
        let n = data.len();
        let mut log_table = vec![0; n+1]; // log(k) (0<=k<=n)
        for i in 2..n+1 {
            log_table[i] = log_table[i >> 1] + 1;
        }
        // dbg!(&log_table);

        let mut table = vec![vec![n; n]; log_table[n]+1];
        // 2^k
        for i in 0..n {
            table[0][i] = i;
        }

        for k in 1..table.len() {
            // dbg!(&table);
            let half_jmp = 1 << (k-1);
            for i in 0..n {
                let first = table[k-1][i];
                table[k][i] = first;
                
                if i+half_jmp < n {
                    let second = table[k-1][i+half_jmp];
                    assert!(first < n);
                    assert!(second < n);
                    if data[first] <= data[second] {
                        table[k][i] = first;
                    } else {
                        table[k][i] = second;
                    }
                }
            }
        }
        // dbg!(&table);

        Self {
            data,
            log_table,
            table,
        }
    }

    // [a, b)
    fn query(&self, a: usize, b: usize) -> usize {
        let d = b - a;
        let k = self.log_table[d];
        let first = self.table[k][a];
        let second = self.table[k][b-(1<<k)];
        if self.data[first] <= self.data[second] {
            first
        } else {
            second
        }
    }
}