2017-11-20 1 views
1

내 코드 N 최고 값의 인덱스를 얻기 :NumPy와 배열

import numpy as np 
N = 2 
a = np.array([[0.5, 0.3, 0.2], 
       [0.2, 0.6, 0.2], 
       [0.3, 0.2, 0.7], 
       [np.nan, 0.2, 0.8],      
       [np.nan, np.nan, 0.8]      
       ]) 

ind = np.argsort(np.where(np.isnan(a), -1, a), axis=1)[:, -N:] 


a 
Out[2]: 
array([[ 0.5, 0.3, 0.2], 
     [ 0.2, 0.6, 0.2], 
     [ 0.3, 0.2, 0.7], 
     [ nan, 0.2, 0.8], 
     [ nan, nan, 0.8]]) 

ind 
Out[3]: 
array([[1, 0], 
     [2, 1], 
     [0, 2], 
     [1, 2], 
     [1, 2]], dtype=int64) 

IND를 [: 1] 가장 높고 IND 인 [: 0] 두 번째로 높은

경우를 제외 괜찮

마지막 열에 2 명이 앉는다. 나노가 두 번째로 큰 값을 무시하는 방법은 무엇입니까? 원하는 출력이 될 것이다 :

array([[1, 0], 
     [2, 1], 
     [0, 2], 
     [1, 2], 
     [nan, 2]], dtype=int64) 

보너스 질문 : 무작위로 [1 ,:]의 경우에 타이를 깰?

답변

1

Advanced-index하고, 우리에게 다음의 선택을 할 np.where와 함께 사용할 수있는 마스크를, 줄과 같이 할 NaNs 확인 - 배열 NaNfloat DTYPE로 변환 될 수있는 것을

In [244]: a_ind = a[np.arange(ind.shape[0])[:,None],ind] 

In [245]: mask = np.isnan(a_ind) 

In [246]: np.where(mask, np.nan, ind) 
Out[246]: 
array([[ 1., 0.], 
     [ 2., 1.], 
     [ 0., 2.], 
     [ 1., 2.], 
     [ nan, 2.]]) 

주 따라서 최종 출력도 float dtype이됩니다.