Rustコトハジメ

プログラミング言語Rustと競プロに関する情報をお届けします。

二分探索を一般化するライブラリを作りました

二分探索の一般化

www.rustforbeginners.com

で話したとおり、二分探索はTFの列に出現する最初のTを探す操作に他なりません。従って、インデックスを入力として、ブール値を出力とする関数を与えてあげれば、最初のTないしはFとTの境界を出力するという関数が書けるはずです。二分探索は頻出して、その度に書くのが嫌なのでライブラリとして一般化することにしました。

struct BinarySearch<F> {
    p: F,
    lower: usize,
    upper: usize,
}

impl <F: Fn(usize) -> bool> BinarySearch<F> {
    fn search(&self) -> (Option<usize>, usize) {
        let mut lb = self.lower;
        let mut ub = self.upper+2;
        while ub - lb > 1 {
            let mid = (lb+ub)/2;
            if (self.p)(mid-1) {
                ub = mid;
            } else {
                lb = mid;
            }
        }
        let former = if lb == self.lower {
            None
        } else {
            Some(lb-1)
        };
        let latter = ub-1;
        (former, latter)
    }
}

この関数は先頭にFが詰まっていて、それからTが続くような遅延計算される配列pを入力として、FとTの境界を出力します。配列が全部Tの場合、境界の前側は-1で、Rustではusizeは正の数しかとることが出来ませんから、Noneを返すことにしています。

C++の標準ライブラリにはlower_boundとupper_boundというのが定義されていて、その考え方が他の言語のライブラリにも見られることがありますが、私はわかりにくいと思っていて、むしろ「TF境界を求める」と一つの概念でとらえた方が楽ですし、問い系の二分探索にも自然に適用出来ます。

問題を解いてみる

E - 最悪の教頭 (Worst Head Teacher)

は、コンテスト中に正しい方針にたどり着いていたのですが、値の大小とインデックスの方向が逆だったりしたこともあり少しわかりにくく、コンテスト中には実装に失敗しました。二分探索をもっと抽象的にやらないと自分の頭では、少しややこしい論理になると間違えてしまうことがわかったので、抽象化しようと思いました。

このライブラリのテストも兼ねて、リベンジしました。

fn solve() {
    input! {
        N: usize, Q: usize,
        D: [i64; N],
        TLR: [(i64,i64,i64); Q],
    }

    let mut d_acc = vec![0; N+1];
    d_acc[0] = 0;
    for i in 1..N+1 {
        d_acc[i] = d_acc[i-1] + D[i-1] - 1;
    }
    
    // principal: 0
    // kids: 1..N+1
    let pos = |t: i64, i: usize| {
        if t < d_acc[i] {
            -(i as i64)
        } else {
            -(i as i64) + (t - d_acc[i])
        }
    };

    for (t,l,r) in TLR {
        let r_bs = BinarySearch {
            p: |i: usize| {
                pos(t,i) <= r
            },
            lower: 0,
            upper: N,
        };
        let l_bs = BinarySearch {
            p: |i: usize| {
                pos(t,i) < l
            },
            lower: 0,
            upper: N,
        };
        println!("{}", l_bs.search().1 - r_bs.search().1);
    }
}