JUNのブログ

JUNのブログ

活動記録や技術メモ

弱い人と強い人のエラトステネスの篩

はじめに

素数列挙アルゴリズムの1つに エラトステネスの篩 というものがあります.

ja.wikipedia.org

このアルゴリズムを弱い自分がWikipediaの記事通りに実装したのと, 強い他の人が実装したのでは計算にかかる時間が大きく違ったので, なぜ計算時間に差があるのか, どのように改善出来るか調べてみようと思います.

今回実装した関数

自分が実装した関数.

import math

def get_primes(n):
    prime_list = []

    # 探索リストに2からnまでの整数を昇順で入れる
    target_list = [i for i in range(2, n)]

    n_sqrt = math.sqrt(n)

    # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う
    while target_list[0] < n_sqrt:
        # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす
        prime = target_list[0]
        prime_list.append(prime)
        target_list.remove(prime)

        for num in target_list:
            if num % prime == 0:
                target_list.remove(num)

    # 探索リストに残った数を素数リストに移動して処理終了.
    prime_list.extend(target_list)

    return prime_list

他の人の実装

def get_sieve_of_eratosthenes(n):
    if not isinstance(n, int):
        raise TypeError('n is int type.')
    if n < 2:
        raise ValueError('n is more than 2')
    prime = [2]
    limit = int(n**0.5)
    data = [i + 1 for i in range(2, n, 2)]
    while True:
        p = data[0]
        if limit <= p:
            return prime + data
        prime.append(p)
        data = [e for e in data if e % p != 0]

Pythonで作るエラトステネスのふるい よりお借りしました.

計算時間

n = 10000 とした時に計算にかかった時間は以下のようになりました.

関数名 計算時間
get_primes 444[ms]
get_sieve_of_eratosthenes 11[ms]

詳しい計算時間

なぜ私の実装した エラトステネスの篩 関数 get_primes() は遅いのでしょうか?

line-profiler を使って各行の処理時間を見てみます.

Timer unit: 1e-06 s

Total time: 0.444782 s
File: test.py
Function: get_primes at line 3

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           @profile
     4                                           def get_primes(n):
     5         1          6.0      6.0      0.0      prime_list = []
     6
     7                                               # 探索リストに2からnまでの整数を昇順で入れる
     8         1       1322.0   1322.0      0.3      target_list = [i for i in range(2, n + 1)]
     9
    10         1          5.0      5.0      0.0      n_sqrt = math.sqrt(n)
    11
    12                                               # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う
    13        26         52.0      2.0      0.0      while target_list[0] < n_sqrt:
    14                                                   # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす
    15        25         17.0      0.7      0.0          prime = target_list[0]
    16        25         31.0      1.2      0.0          prime_list.append(prime)
    17        25        137.0      5.5      0.0          target_list.remove(prime)
    18
    19     43684      28417.0      0.7      6.4          for num in target_list:
    20     43659      29833.0      0.7      6.7              if num % prime == 0:
    21      8770     384954.0     43.9     86.5                  target_list.remove(num)
    22
    23                                               # 探索リストに残った数を素数リストに移動して処理終了.
    24         1          7.0      7.0      0.0      prime_list.extend(target_list)
    25
    26         1          1.0      1.0      0.0      return prime_list

まず, target_list.remove(num) がめちゃくちゃ遅いことがわかります.

Pythonistaなら知らないと恥ずかしい計算量のはなし より以下の表を見てみましょう. 数式が変換されていないですが, 脳内で変換してください.

操作 平均時評価 最悪時評価
Copy $O(n)$ $O(n)$
Append $O(1)$ $O(1)$
Insert $O(n)$ $O(n)$
Get Item $O(1)$ $O(1)$
Set Item $O(1)$ $O(1)$
Delete Item $O(n)$ $O(n)$
Iteration $O(n)$ $O(n)$
Get Slice $O(k)$ $O(k)$
Del Slice $O(n)$ $O(n)$
Set Slice $O(k+n)$ $O(k+n)$
Extend $O(k)$ $O(k)$
Sort $O(n \log n)$ $O(n \log n)$
Multiply $O(nk)$ $O(nk)$
x in s $O(n)$
min(s), max(s) $O(n)$
Get Length $O(1)$ $O(1)$

この表の Delete Item の行から, Pythonのリストから要素を削除するには平均 $O(n)$ かかることがわかります.

Pythonではリストは内部的にはC言語の配列として表しているようです。そのため、先頭要素の追加や削除を行うとそれ以降の要素をすべて移動する必要があるため大きなコストがかかります。なので先頭に要素を追加したり削除する必要がある場合は、代わりにcollections.dequeを使用することを推奨しています。

Pythonistaなら知らないと恥ずかしい計算量のはなし

とのことなので, このあたりが改善出来そうです.

改善しよう

偶数を対象から排除する

get_primes() では愚直にエラトステネスの篩をそのまま表現したコードになっています. しかし, よく観察すると 2以外の偶数はすべて素数ではないです. つまり, 2以外の偶数を最初から対象に入れなければ対象のリストの要素数が半分になり, 対象のリストの作成処理と要素を削除する回数も半分になりそうです.

偶数を最初から対象に入れない場合のコードは以下のようになります.

import math


def get_primes(n):
    prime_list = [2]

    # 探索リストに2からnまでの奇数を昇順で入れる
    # 2以外の偶数は素数ではないことがわかっている
    target_list = [i + 1 for i in range(2, n, 2)]

    n_sqrt = math.sqrt(n)

    # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う
    while target_list[0] < n_sqrt:
        # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす
        prime = target_list[0]
        prime_list.append(prime)
        target_list.remove(prime)

        for num in target_list:
            if num % prime == 0:
                target_list.remove(num)

    # 探索リストに残った数を素数リストに移動して処理終了.
    prime_list.extend(target_list)

    return prime_list


n = 10000
get_primes(n)
Timer unit: 1e-06 s

Total time: 0.178319 s
File: primes.py
Function: get_primes at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     4                                           @profile
     5                                           def get_primes(n):
     6         1          7.0      7.0      0.0      prime_list = []
     7
     8                                               # 探索リストに2からnまでの奇数を昇順で入れる
     9                                               # 2以外の偶数は素数ではないことがわかっている
    10         1        936.0    936.0      0.5      target_list = [i + 1 for i in range(2, n, 2)]
    11
    12         1          4.0      4.0      0.0      n_sqrt = math.sqrt(n)
    13
    14                                               # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う
    15        25         70.0      2.8      0.0      while target_list[0] < n_sqrt:
    16                                                   # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす
    17        24         23.0      1.0      0.0          prime = target_list[0]
    18        24         41.0      1.7      0.0          prime_list.append(prime)
    19        24         96.0      4.0      0.1          target_list.remove(prime)
    20
    21     38683      28402.0      0.7     15.9          for num in target_list:
    22     38659      30482.0      0.8     17.1              if num % prime == 0:
    23      3771     118248.0     31.4     66.3                  target_list.remove(num)
    24
    25                                               # 探索リストに残った数を素数リストに移動して処理終了.
    26         1          9.0      9.0      0.0      prime_list.extend(target_list)
    27
    28         1          1.0      1.0      0.0      return prime_list

この改善によって 444[ms] だったものが 178[ms] まで速くなりました!!

list.remove(el) を使わずに書く

list.remove() の計算量が平均 $O(n)$ かかるということで, 偶数を対象から取り除いたことで list.remove()` が激遅なのが際立ってます. なんとかしましょう.

対策としては単純に list.remove() を使わずに実装すれば良いです. かんたんな話, list.remove() の計算量が $O(n)$ なのに対して, list.append() の計算量は最悪の場合でも $O(1)$ なので, 既存リストから複数の要素を削除する場合には, 愚直に削除するより, 新たにリストを作成し, その段階で条件を設けて要素を排除するほうが計算量が少なくて済みます.

方法 計算量 説明
リストからm個の要素を削除 $O(n)*m=O(nm)$ $O(n)$の処理をm回行うので$O(nm)$
新規にリストを作成し $O(n)+O(n) = O(2n) = O(n)$ リストの作成が$O(n)$で, ifでの判定が$O(1)$をn回で$O(n)$
import math


def get_primes(n):
    prime_list = [2]

    # 探索リストに2からnまでの奇数を昇順で入れる
    # 2以外の偶数は素数ではないことがわかっている
    target_list = [i + 1 for i in range(2, n, 2)]

    n_sqrt = math.sqrt(n)

    # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う
    while target_list[0] < n_sqrt:
        # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす
        prime = target_list[0]
        prime_list.append(prime)

        # primeの倍数を探索リストからふるい落とす
        # 既存のリストから要素を削除するのではなく, 新規にリストを作成する.
        target_list = [i for i in target_list if i % prime != 0]

    # 探索リストに残った数を素数リストに移動して処理終了.
    prime_list.extend(target_list)

    return prime_list


n = 10000
get_primes(n)

プロファイルした結果は以下のようになりました.

Timer unit: 1e-06 s

Total time: 0.010486 s
File: primes.py
Function: get_primes at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     4                                           @profile
     5                                           def get_primes(n):
     6         1         11.0     11.0      0.1      prime_list = []
     7
     8                                               # 探索リストに2からnまでの奇数を昇順で入れる
     9                                               # 2以外の偶数は素数ではないことがわかっている
    10         1        978.0    978.0      9.3      target_list = [i + 1 for i in range(2, n, 2)]
    11
    12         1          4.0      4.0      0.0      n_sqrt = math.sqrt(n)
    13
    14                                               # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う
    15        25         37.0      1.5      0.4      while target_list[0] < n_sqrt:
    16                                                   # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす
    17        24         24.0      1.0      0.2          prime = target_list[0]
    18        24         26.0      1.1      0.2          prime_list.append(prime)
    19
    20                                                   # primeの倍数を探索リストからふるい落とす
    21                                                   # 既存のリストから要素を削除するのではなく, 新規にリストを作成する.
    22        24       9399.0    391.6     89.6          target_list = [i for i in target_list if i % prime != 0]
    23
    24                                               # 探索リストに残った数を素数リストに移動して処理終了.
    25         1          7.0      7.0      0.1      prime_list.extend(target_list)
    26
    27         1          0.0      0.0      0.0      return prime_list

444[ms] → 178[ms] → 10[ms] まで改善できました!!

これは, 最初に紹介した get_get_sieve_of_eratosthenes() とほとんど同じ処理時間ですし, 関数内の処理もほとんど同じになりました!! (バリデーション処理があるので厳密には違いますが...)

感想

今回はエラトステネスの篩というアルゴリズムを実際に自分でWikipediaの記事通り実装して, そこから他の人の実装を参考にしながら最適化をしていきました.

自分はAtCoderはいつまでも灰色で, アルゴリズム弱弱人間なのでこのようなタイトルの記事になりました.

アルゴリズムや計算量の求め方はなんとなくでしか理解していないので, もし間違え等あれば私の成長のためにも優しく指摘してもらえると嬉しいです.

自分はローカルで記事を書いてはてなブログに貼り付けて記事を投稿しているのですが, はてなブログの数式の仕様がめんどくさすぎますね. 早くなんとかしてほしいものです.

では, 今日はこのへんで. お疲れ様でした👋