Python 階乗・順列・二項係数を素数を割った余りを求めるライブラリ
階乗・順列・二項係数の剰余を求めるライブラリ
競プロでよく使うやつをライブラリ化しました。Moduloが素数でない時は上手く動きません。
class Factorial(): def __init__(self, mod=10**9 + 7): self.mod = mod self._factorial = [1] self._size = 1 self._factorial_inv = [1] self._size_inv = 1 def __call__(self, n): '''n! % mod ''' return self.fact(n) def fact(self, n): '''n! % mod ''' if n >= self.mod: return 0 self.make(n) return self._factorial[n] def fact_inv(self, n): '''n!^-1 % mod ''' if n >= self.mod: raise ValueError('Modinv is not exist! arg={}'.format(n)) self.make_inv(n) return self._factorial_inv[n] def comb(self, n, r): ''' nCr % mod ''' if r > n: return 0 t = self.fact_inv(n-r)*self.fact_inv(r) % self.mod return self(n)*t % self.mod def comb_with_repetition(self, n, r): ''' nHr % mod ''' t = self.fact_inv(n-1)*self.fact_inv(r) % self.mod return self(n+r-1)*t % self.mod def perm(self, n, r): ''' nPr % mod ''' if r > n: return 0 return self(n)*self.fact_inv(n-r) % self.mod @staticmethod def xgcd(a, b): ''' return (g, x, y) such that a*x + b*y = g = gcd(a, b) ''' x0, x1, y0, y1 = 0, 1, 1, 0 while a != 0: (q, a), b = divmod(b, a), a y0, y1 = y1, y0 - q * y1 x0, x1 = x1, x0 - q * x1 return b, x0, y0 def modinv(self, n): g, x, _ = self.xgcd(n, self.mod) if g != 1: raise ValueError('Modinv is not exist! arg={}'.format(n)) return x % self.mod def make(self, n): if n >= self.mod: n = self.mod if self._size < n+1: for i in range(self._size, n+1): self._factorial.append(self._factorial[i-1]*i % self.mod) self._size = n+1 def make_inv(self, n): if n >= self.mod: n = self.mod self.make(n) if self._size_inv < n+1: for i in range(self._size_inv, n+1): self._factorial_inv.append(self.modinv(self._factorial[i])) self._size_inv = n+1
解説
を_factorial, を_factorial_invに記録するみたいな感じです。
モジュロ逆数を求めるアルゴリズムをフェルマーの小定理に頼っているため、モジュロが素数でないと上手く動作しません。拡張ユークリッド互除法をつかって逆元を計算したほうが拡張性が高いですが、大目に見てください。
拡張ユークリッド互除法に変更しました!
他にも順列や二項係数、重複組合せなどをサポートしてます。
下のページの先頭クエリ100個で動作テストしました。
yukicoder.me
コードはこんな感じ。
class Factorial(): # 省略 fact = Factorial() comb = fact.comb perm = fact.perm comb_with_repetition = fact.comb_with_repetition inputs = open(0).read().split() for i in range(int(inputs[0])): q, n, r = inputs[i+1].replace('(', ',').replace(')', '').split(',') n, r = map(int, (n, r)) if q == 'C': print(comb(n, r)) elif q == 'P': print(perm(n, r)) else: print(comb_with_repetition(n, r))
その他
拡張ユークリッドを実装したら更新するかも。再帰だとPypyで遅くなるので非再帰で書きたいんですが、まあがんばります。
20200719 改善済み