Combinationの求め方まとめ
c++でnCkの計算でnの値に応じた2種類の導出方法をメモ。
- nの値が50程度でlong long型に収まるとき
このときはオーバーフローや時間制約を気にする必要はないが、階乗を愚直に計算することはできないので注意する。
nCk = n-1Ck-1 + n-1Ckというパスカルの三角形から導かれる公式を用いることでO(n^2)で計算が可能である。
#include <bits/stdc++.h> #define rep(i, m, n) for(int i = m; i < (n); i++) #define print(x) cout << (x) << endl; #define printa(x,n) for(int i = 0; i < n; i++){ cout << (x[i]) << " ";} cout << endl; #define printa2(x,m,n) for(int i = 0; i < m; i++){ for(int j = 0; j < n; j++){ cout << x[i][j] << " ";} cout << endl;} #define printp(x,n) for(int i = 0; i < n; i++){ cout << "(" << x[i].first << ", " << x[i].second << ") "; } cout << endl; #define INF (1e9) using namespace std; typedef long long ll; typedef struct{ int x; int y; } P; typedef struct{ ll to; ll cost; } edge; typedef pair<ll, ll> lpair; int N,K; ll v[51][51]; void nck(){ rep(i,0,N+1){ v[i][0] = 1; v[i][i] = 1; } rep(i,1,N+1){ rep(j,1,i){ v[i][j] = v[i-1][j-1] + v[i-1][j]; } } } int main(){ cin.tie(0); ios::sync_with_stdio(false); cin >> N >> K; nck(); print(v[N][K]); }
50C25で1e14程度のオーダーなのでN<=50程度の制約であればこちらの方法で問題ない。
- nが大きくmodを用いる場合
当然nが1000などになってくればオーバーフローするので1e9+7などの値でmodを取ることを要求してくるのが普通である。
また、nが100000程度になればO(n^2)で計算を行うことも不可能である。
このような場合、逆元を用いたmod計算を行うことで結果を導出することができる。
ここで用いるのがフェルマーの小定理。
aとpが互いに素であるときに
が成り立つというものである。詳しいことはこちら↓
mathtrain.jp
この式の両辺にをかけると
となる。したがって、の逆数のmodはのmodに等しいということがわかる。
これで準備が整う。こうして求まったmodの値を保持しておいて、の3つを掛け合わせて再びmodを取る(この時オーバーフローに注意)ことで答えが求まる。
その場で考えずに貼り付けるケースが多そうだから原理はそこまで重要じゃなさそうだけど。
以下実装の例。
#include <bits/stdc++.h> #define rep(i, m, n) for(int i = m; i < (n); i++) #define print(x) cout << (x) << endl; #define printa(x,n) for(int i = 0; i < n; i++){ cout << (x[i]) << " ";} cout << endl; #define printa2(x,m,n) for(int i = 0; i < m; i++){ for(int j = 0; j < n; j++){ cout << x[i][j] << " ";} cout << endl;} #define INF (1e9) typedef long long ll; typedef struct{ int x; int y; } P; using namespace std; const ll max_N = 110000; ll fac[max_N + 1000], facinv[max_N + 1000]; const ll MOD = 1e9 + 7; ll N,K; ll power(ll x, ll n){ if(n == 0) return 1LL; ll res = power(x * x % MOD, n/2); if(n % 2 == 1){ res = res * x % MOD; } return res; } ll nck(ll n, ll k){ if(k == 0 || n == k){ return 1; } return fac[n] * facinv[k] % MOD * facinv[n-k] % MOD; } ll npk(ll n, ll k){ if(k == 0 || n == k){ return 1; } return fac[n] * facinv[n-k] % MOD; } int main(){ cin.tie(0); ios::sync_with_stdio(false); cin >> N >> K; fac[0] = 0; fac[1] = 1; rep(i,2,max_N){ fac[i] = (fac[i-1] * i) % MOD; } rep(i,0,max_N){ facinv[i] = power(fac[i], MOD - 2); } print(nck(N,K)); }
facが通常の階乗のmodで、facinvが逆元を表している。
これを用いるとn = 100000程度であっても答えを得ることができる。