はじめに
素数列挙アルゴリズムの1つに エラトステネスの篩 というものがあります.
このアルゴリズムを弱い自分がWikipediaの記事通りに実装したのと, 強い他の人が実装したのでは計算にかかる時間が大きく違ったので, なぜ計算時間に差があるのか, どのように改善出来るか調べてみようと思います.
今回実装した関数
自分が実装した関数.
import math def get_primes(n): prime_list = [] # 探索リストに2からnまでの整数を昇順で入れる target_list = [i for i in range(2, n)] n_sqrt = math.sqrt(n) # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う while target_list[0] < n_sqrt: # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす prime = target_list[0] prime_list.append(prime) target_list.remove(prime) for num in target_list: if num % prime == 0: target_list.remove(num) # 探索リストに残った数を素数リストに移動して処理終了. prime_list.extend(target_list) return prime_list
他の人の実装
def get_sieve_of_eratosthenes(n): if not isinstance(n, int): raise TypeError('n is int type.') if n < 2: raise ValueError('n is more than 2') prime = [2] limit = int(n**0.5) data = [i + 1 for i in range(2, n, 2)] while True: p = data[0] if limit <= p: return prime + data prime.append(p) data = [e for e in data if e % p != 0]
Pythonで作るエラトステネスのふるい よりお借りしました.
計算時間
n = 10000 とした時に計算にかかった時間は以下のようになりました.
関数名 | 計算時間 |
---|---|
get_primes | 444[ms] |
get_sieve_of_eratosthenes | 11[ms] |
詳しい計算時間
なぜ私の実装した エラトステネスの篩 関数 get_primes()
は遅いのでしょうか?
line-profiler を使って各行の処理時間を見てみます.
Timer unit: 1e-06 s Total time: 0.444782 s File: test.py Function: get_primes at line 3 Line # Hits Time Per Hit % Time Line Contents ============================================================== 3 @profile 4 def get_primes(n): 5 1 6.0 6.0 0.0 prime_list = [] 6 7 # 探索リストに2からnまでの整数を昇順で入れる 8 1 1322.0 1322.0 0.3 target_list = [i for i in range(2, n + 1)] 9 10 1 5.0 5.0 0.0 n_sqrt = math.sqrt(n) 11 12 # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う 13 26 52.0 2.0 0.0 while target_list[0] < n_sqrt: 14 # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす 15 25 17.0 0.7 0.0 prime = target_list[0] 16 25 31.0 1.2 0.0 prime_list.append(prime) 17 25 137.0 5.5 0.0 target_list.remove(prime) 18 19 43684 28417.0 0.7 6.4 for num in target_list: 20 43659 29833.0 0.7 6.7 if num % prime == 0: 21 8770 384954.0 43.9 86.5 target_list.remove(num) 22 23 # 探索リストに残った数を素数リストに移動して処理終了. 24 1 7.0 7.0 0.0 prime_list.extend(target_list) 25 26 1 1.0 1.0 0.0 return prime_list
まず, target_list.remove(num)
がめちゃくちゃ遅いことがわかります.
Pythonistaなら知らないと恥ずかしい計算量のはなし より以下の表を見てみましょう. 数式が変換されていないですが, 脳内で変換してください.
操作 | 平均時評価 | 最悪時評価 |
---|---|---|
Copy | $O(n)$ | $O(n)$ |
Append | $O(1)$ | $O(1)$ |
Insert | $O(n)$ | $O(n)$ |
Get Item | $O(1)$ | $O(1)$ |
Set Item | $O(1)$ | $O(1)$ |
Delete Item | $O(n)$ | $O(n)$ |
Iteration | $O(n)$ | $O(n)$ |
Get Slice | $O(k)$ | $O(k)$ |
Del Slice | $O(n)$ | $O(n)$ |
Set Slice | $O(k+n)$ | $O(k+n)$ |
Extend | $O(k)$ | $O(k)$ |
Sort | $O(n \log n)$ | $O(n \log n)$ |
Multiply | $O(nk)$ | $O(nk)$ |
x in s | $O(n)$ | |
min(s), max(s) | $O(n)$ | |
Get Length | $O(1)$ | $O(1)$ |
この表の Delete Item
の行から, Pythonのリストから要素を削除するには平均 $O(n)$ かかることがわかります.
Pythonではリストは内部的にはC言語の配列として表しているようです。そのため、先頭要素の追加や削除を行うとそれ以降の要素をすべて移動する必要があるため大きなコストがかかります。なので先頭に要素を追加したり削除する必要がある場合は、代わりにcollections.dequeを使用することを推奨しています。
とのことなので, このあたりが改善出来そうです.
改善しよう
偶数を対象から排除する
get_primes()
では愚直にエラトステネスの篩をそのまま表現したコードになっています. しかし, よく観察すると 2以外の偶数はすべて素数ではないです. つまり, 2以外の偶数を最初から対象に入れなければ対象のリストの要素数が半分になり, 対象のリストの作成処理と要素を削除する回数も半分になりそうです.
偶数を最初から対象に入れない場合のコードは以下のようになります.
import math def get_primes(n): prime_list = [2] # 探索リストに2からnまでの奇数を昇順で入れる # 2以外の偶数は素数ではないことがわかっている target_list = [i + 1 for i in range(2, n, 2)] n_sqrt = math.sqrt(n) # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う while target_list[0] < n_sqrt: # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす prime = target_list[0] prime_list.append(prime) target_list.remove(prime) for num in target_list: if num % prime == 0: target_list.remove(num) # 探索リストに残った数を素数リストに移動して処理終了. prime_list.extend(target_list) return prime_list n = 10000 get_primes(n)
Timer unit: 1e-06 s Total time: 0.178319 s File: primes.py Function: get_primes at line 4 Line # Hits Time Per Hit % Time Line Contents ============================================================== 4 @profile 5 def get_primes(n): 6 1 7.0 7.0 0.0 prime_list = [] 7 8 # 探索リストに2からnまでの奇数を昇順で入れる 9 # 2以外の偶数は素数ではないことがわかっている 10 1 936.0 936.0 0.5 target_list = [i + 1 for i in range(2, n, 2)] 11 12 1 4.0 4.0 0.0 n_sqrt = math.sqrt(n) 13 14 # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う 15 25 70.0 2.8 0.0 while target_list[0] < n_sqrt: 16 # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす 17 24 23.0 1.0 0.0 prime = target_list[0] 18 24 41.0 1.7 0.0 prime_list.append(prime) 19 24 96.0 4.0 0.1 target_list.remove(prime) 20 21 38683 28402.0 0.7 15.9 for num in target_list: 22 38659 30482.0 0.8 17.1 if num % prime == 0: 23 3771 118248.0 31.4 66.3 target_list.remove(num) 24 25 # 探索リストに残った数を素数リストに移動して処理終了. 26 1 9.0 9.0 0.0 prime_list.extend(target_list) 27 28 1 1.0 1.0 0.0 return prime_list
この改善によって 444[ms] だったものが 178[ms] まで速くなりました!!
list.remove(el) を使わずに書く
list.remove()
の計算量が平均 $O(n)$ かかるということで, 偶数を対象から取り除いたことで list
.remove()` が激遅なのが際立ってます. なんとかしましょう.
対策としては単純に list.remove()
を使わずに実装すれば良いです. かんたんな話, list.remove()
の計算量が $O(n)$ なのに対して, list.append()
の計算量は最悪の場合でも $O(1)$ なので, 既存リストから複数の要素を削除する場合には, 愚直に削除するより, 新たにリストを作成し, その段階で条件を設けて要素を排除するほうが計算量が少なくて済みます.
方法 | 計算量 | 説明 |
---|---|---|
リストからm個の要素を削除 | $O(n)*m=O(nm)$ | $O(n)$の処理をm回行うので$O(nm)$ |
新規にリストを作成し | $O(n)+O(n) = O(2n) = O(n)$ | リストの作成が$O(n)$で, ifでの判定が$O(1)$をn回で$O(n)$ |
import math def get_primes(n): prime_list = [2] # 探索リストに2からnまでの奇数を昇順で入れる # 2以外の偶数は素数ではないことがわかっている target_list = [i + 1 for i in range(2, n, 2)] n_sqrt = math.sqrt(n) # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う while target_list[0] < n_sqrt: # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす prime = target_list[0] prime_list.append(prime) # primeの倍数を探索リストからふるい落とす # 既存のリストから要素を削除するのではなく, 新規にリストを作成する. target_list = [i for i in target_list if i % prime != 0] # 探索リストに残った数を素数リストに移動して処理終了. prime_list.extend(target_list) return prime_list n = 10000 get_primes(n)
プロファイルした結果は以下のようになりました.
Timer unit: 1e-06 s Total time: 0.010486 s File: primes.py Function: get_primes at line 4 Line # Hits Time Per Hit % Time Line Contents ============================================================== 4 @profile 5 def get_primes(n): 6 1 11.0 11.0 0.1 prime_list = [] 7 8 # 探索リストに2からnまでの奇数を昇順で入れる 9 # 2以外の偶数は素数ではないことがわかっている 10 1 978.0 978.0 9.3 target_list = [i + 1 for i in range(2, n, 2)] 11 12 1 4.0 4.0 0.0 n_sqrt = math.sqrt(n) 13 14 # ふるい落とし操作を探索リストの先頭値がnの平方根に達するまで行う 15 25 37.0 1.5 0.4 while target_list[0] < n_sqrt: 16 # 探索リストの先頭の数を素数リストに移動し, その倍数を探索リストからふるい落とす 17 24 24.0 1.0 0.2 prime = target_list[0] 18 24 26.0 1.1 0.2 prime_list.append(prime) 19 20 # primeの倍数を探索リストからふるい落とす 21 # 既存のリストから要素を削除するのではなく, 新規にリストを作成する. 22 24 9399.0 391.6 89.6 target_list = [i for i in target_list if i % prime != 0] 23 24 # 探索リストに残った数を素数リストに移動して処理終了. 25 1 7.0 7.0 0.1 prime_list.extend(target_list) 26 27 1 0.0 0.0 0.0 return prime_list
444[ms] → 178[ms] → 10[ms] まで改善できました!!
これは, 最初に紹介した get_get_sieve_of_eratosthenes()
とほとんど同じ処理時間ですし, 関数内の処理もほとんど同じになりました!! (バリデーション処理があるので厳密には違いますが...)
感想
今回はエラトステネスの篩というアルゴリズムを実際に自分でWikipediaの記事通り実装して, そこから他の人の実装を参考にしながら最適化をしていきました.
自分はAtCoderはいつまでも灰色で, アルゴリズム弱弱人間なのでこのようなタイトルの記事になりました.
アルゴリズムや計算量の求め方はなんとなくでしか理解していないので, もし間違え等あれば私の成長のためにも優しく指摘してもらえると嬉しいです.
自分はローカルで記事を書いてはてなブログに貼り付けて記事を投稿しているのですが, はてなブログの数式の仕様がめんどくさすぎますね. 早くなんとかしてほしいものです.
では, 今日はこのへんで. お疲れ様でした👋