Binary Search

二分查找

二分查找模版

def binary_search(arr, target):
    """Template: binary search over an ascending sequence, closed interval.

    The search window is the closed interval [left, right]; it is empty
    once left > right, which is when the loop stops.

    Args:
        arr: sequence sorted in ascending order.
        target: value to look for.

    Returns:
        An index ``i`` with ``arr[i] == target``, or -1 when ``arr`` is
        empty or does not contain ``target``.
    """
    if not arr:
        return -1
    left = 0
    right = len(arr) - 1
    while left <= right:
        # Written as left + half-width rather than (left + right) // 2.
        mid = left + ((right - left) >> 1)
        pivot = arr[mid]
        if pivot < target:
            left = mid + 1
        elif pivot > target:
            right = mid - 1
        else:
            # Customization point: a plain search returns the match here.
            # A boundary search would shrink the window instead; leaving
            # the window unchanged would make the loop spin forever.
            return mid
    # Customization point: the window is empty, so nothing matched.
    return -1

二分查找实现

#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import bisect


def binary_search(arr, target):
    """Return an index of ``target`` in ascending ``arr``, or -1 if absent.

    The search window is the half-open interval [lo, hi). When ``arr``
    holds duplicates of ``target``, the index returned is whichever
    matching element a probe lands on first, not necessarily the leftmost.
    An empty (falsy) ``arr`` yields -1.
    """
    lo = 0
    # A falsy arr gets a zero-width window, so the loop body never runs.
    hi = len(arr) if arr else 0
    while lo < hi:
        probe = lo + (hi - lo) // 2
        item = arr[probe]
        if item < target:
            # Everything up to and including probe is too small.
            lo = probe + 1
        elif item > target:
            # probe itself is excluded because hi is an open bound.
            hi = probe
        else:
            return probe
    return -1


def low_bound(arr, target):
    """Return the index of the first element >= ``target`` in ascending ``arr``.

    Returns ``len(arr)`` when every element is smaller than ``target``.
    Note that an empty (falsy) ``arr`` yields -1 rather than 0.
    """
    if not arr:
        return -1
    lo, hi = 0, len(arr)
    # Invariant: elements before lo are < target, elements from hi on are >= target.
    while hi > lo:
        probe = lo + (hi - lo) // 2
        if arr[probe] < target:
            lo = probe + 1
        else:
            # arr[probe] >= target: probe may be the answer, keep it in range.
            hi = probe
    return lo


def upper_bound(arr, target):
    """Return the index of the first element > ``target`` in ascending ``arr``.

    Returns ``len(arr)`` when no element is greater than ``target``.
    Note that an empty (falsy) ``arr`` yields -1 rather than 0.
    """
    if not arr:
        return -1
    lo, hi = 0, len(arr)
    # Invariant: elements before lo are <= target, elements from hi on are > target.
    while hi > lo:
        probe = lo + (hi - lo) // 2
        if arr[probe] > target:
            # probe may be the answer, keep it in range.
            hi = probe
        else:
            # arr[probe] <= target: the answer lies strictly to the right.
            lo = probe + 1
    return hi


def low_bound_reverse(arr, target):
    """Return the index of the first element <= ``target`` in descending ``arr``.

    This is the descending-order counterpart of a lower bound: the leftmost
    position at which ``target`` could be inserted while keeping ``arr``
    sorted from largest to smallest.

    Args:
        arr: sequence sorted in descending order.
        target: value to locate.

    Returns:
        The smallest index ``i`` with ``arr[i] <= target``; ``len(arr)``
        when every element is greater than ``target``; -1 when ``arr`` is
        empty (falsy).
    """
    ret = -1
    if not arr:
        return ret
    left = 0
    right = len(arr)
    # Invariant: elements before left are > target, elements from right on
    # are <= target.
    while left < right:
        mid = left + ((right - left) >> 1)
        if arr[mid] > target:
            # Only strictly greater elements are skipped; equal elements
            # must stay in range because the first of them is the answer.
            left = mid + 1
        else:
            # A plain else guarantees the window shrinks on every pass.
            right = mid
    ret = left
    return ret


def upper_bound_reverse(arr, target):
    """Return the index of the first element < ``target`` in descending ``arr``.

    This is the descending-order counterpart of an upper bound: the position
    just past the last element that is >= ``target``.

    Args:
        arr: sequence sorted in descending order.
        target: value to locate.

    Returns:
        The smallest index ``i`` with ``arr[i] < target``; ``len(arr)``
        when no element is smaller than ``target``; -1 when ``arr`` is
        empty (falsy).
    """
    ret = -1
    if not arr:
        return ret
    left = 0
    right = len(arr)
    # Invariant: elements before left are >= target, elements from right on
    # are < target.
    while left < right:
        mid = left + ((right - left) >> 1)
        if arr[mid] >= target:
            # Equal elements are skipped too: the answer is strictly smaller.
            left = mid + 1
        else:
            # A plain else guarantees the window shrinks on every pass.
            right = mid
    ret = right
    return ret


def _main():
    """Print each search helper's result next to a reference value.

    For the ascending list the reference comes from ``bisect``; for the
    descending list no reference is computed and a constant 0 is printed
    in its place.
    """
    probes = (-1, 5, 7, 8, 11)
    row = '{} diff {} {}'

    # Ascending data 0..10 with one duplicate: index 8 is overwritten with 7.
    arr = list(range(11))
    arr[8] = 7
    print('origin: {}'.format(arr))

    # (function under test, bisect function printed beside it)
    ascending_cases = (
        (binary_search, bisect.bisect_left),
        (low_bound, bisect.bisect_left),
        (upper_bound, bisect.bisect_right),
    )
    for search, reference in ascending_cases:
        for probe in probes:
            print(row.format(probe, search(arr, probe), reference(arr, probe)))

    # Same data in descending order for the *_reverse helpers.
    arr.sort(reverse=True)
    print('sort origin: {}'.format(arr))

    for search in (low_bound_reverse, upper_bound_reverse):
        for probe in probes:
            # Placeholder 0 stands in for the missing reference value.
            print(row.format(probe, search(arr, probe), 0))


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    _main()

Note

二分查找的三个步骤: 1. 预处理:如果序列未排序,则先进行排序 2. 二分查找:使用循环或递归将中间值元素与目标元素进行比较,将区间划分为两个子区间,然后在符合条件的其中一个子区间内进行寻找,直至循环或递归结束。 3. 后处理:在循环或递归完成后,需要对剩余区间的元素中确定符合条件的元素。

  1. left < right //在 left == right(区间为空)时退出,避免死循环;采用左闭右开区间(左闭右开区间既符合直觉,又可以省去代码中大量的 +1/-1 edge case 检查)
  2. left + ((right-left)>>1) //找中间值,避免 mid 溢出
    1. 上位中位数 mid = left + ((right-left)>>1)
    2. 下位中位数 mid = left + ((right-left-1)>>1)
  3. arr[mid] ==, <, > //判断要根据找 target 是第一次出现还是最后一次出现来决定把 mid 给 left 还是 right,核心在于何种方式最大限度地缩小搜索空间

C++ 标准库中采用左闭右开区间,提供了两种边界查找函数,如何使用两种函数实现四种边界查询?

  1. lower_bound 查找 x >= target 的下界,若为 right 则不存在
  2. upper_bound 查找 x > target 的下界,若为 right 则不存在
  3. lower_bound - 1 查找 x < target 的上界,若为 left - 1 则不存在
  4. upper_bound - 1 查找 x <= target 的上界,若为 left - 1 则不存在

Reference