Numpyの2次元行列からn番目に(小さな or 大きな)数値(およびそのインデックス)を抜き出す

以下のようなnumpy行列を考える。

k = np.array([
       [ 10, 50, 30],
       [ 40, 20, 10],
       [ 70, 80, 60]])Code language: PHP (php)

結論から言うとn番目に大きな数値を抜き出すにはnp.partitionとflatternを組み合わせて使うと良い

np.partition(k.flatten(), 0)
-> [10 50 30 40 20 10 70 80 60]
np.partition(k.flatten(), 1)
-> [10 10 30 40 20 50 70 80 60]
np.partition(k.flatten(), 2)
-> [10 10 20 40 30 50 70 80 60]Code language: CSS (css)

このようにnp.partitionはn番目に小さい(0からカウントを開始するので注意)数値をn番目の位置に”パーティション”として設置し, それより小さい数値を左側に, 大きい数値を右側に並べる。従ってパーティションの値を抜き出せばn番目を得られる。

従って0, 1, 2番目に小さい数値は以下のようにして得られる。

np.partition(k.flatten(), 0)[0]
np.partition(k.flatten(), 1)[1]
np.partition(k.flatten(), 2)[2]Code language: CSS (css)

逆に0, 1, 2番目に大きい数値を得たい場合は以下のようになる。前述とは異なりインデックスが一つずれるので注意。

np.partition(k.flatten(), -1)[-1]
np.partition(k.flatten(), -2)[-2]
np.partition(k.flatten(), -3)[-3]Code language: CSS (css)

Numpyの2次元行列からn番目に(小さな or 大きな)数値のインデックスを抜き出す

n番目に大きな数値のインデックスが欲しい場合。np.argpartitionを使う。先ほどと同様の2次元行列kに対して

np.argpartition(k.flatten(), 0)
-> [0 1 2 3 4 5 6 7 8]
#参考:np.partition(k.flatten(), 0) = [10 50 30 40 20 10 70 80 60]

np.argpartition(k.flatten(), 1)
-> [0 5 2 3 4 1 6 7 8]
#参考:np.partition(k.flatten(), 1) = [10 10 30 40 20 50 70 80 60]

np.argpartition(k.flatten(), 2)
-> [0 5 4 3 2 1 6 7 8]
#参考:np.partition(k.flatten(), 2) = [10 10 20 40 30 50 70 80 60]Code language: PHP (php)

要はnp.argpartitionはnp.partitionのインデックスを返すものである。

従って0, 1, 2番目に小さい数値は以下のようにして得られる。

np.argpartition(k.flatten(), 0)[0]
np.argpartition(k.flatten(), 1)[1]
np.argpartition(k.flatten(), 2)[2]Code language: CSS (css)

逆に0, 1, 2番目に大きい数値を得たい場合は以下のようになる。

np.argpartition(k.flatten(), -1)[-1]
np.argpartition(k.flatten(), -2)[-2]
np.argpartition(k.flatten(), -3)[-3]Code language: CSS (css)

さて, 以上のnp.argpartitionで得られるのはnp.flatten()によって1次元行列に『潰された』行列上のインデックスであり, 我々が知りたい元々の2次元行列のインデックスとは異なる。そこでnp.argpartitionで得た1次元行列のインデックスをnp.unravel_index()を使うことで2次元行列のインデックスに戻す。

i = np.argpartition(k.flatten(), -1)[-1] # 「潰した」1次元行列内の最大値を与えるインデックス7を取得
np.unravel_index(i, k.shape()) # インデックスiを本来の2次元行列の形k.shapeに戻す
-> (2, 1)Code language: PHP (php)

最大値を与えるインデックス(2, 1)がめでたく得られた。(やはり0からカウントを開始するので注意)

Leave a Reply

CAPTCHA