2013-06-17 3 views
4

숫자가 작아서 조건이 참인 범위/범위를 찾아야한다고 가정 해보십시오. 예를 들어, 내가 항목이 1보다 큰 스팬 찾기 위해 노력하고있는 다음과 같은 배열이 있습니다NumPy를 사용하여 조건이 참인 span을 찾습니다.

[0, 0, 0, 2, 2, 0, 2, 2, 2, 0] 

내가 인덱스를 찾을 필요가 있습니다 (시작, 중지) :

(3, 5) 
(6, 9) 

시작 및 종료 포 찾을 numpy.argminnumpy.argmax를 사용하여 배열을 통해 반복 다음

truth = data > threshold 

과 : 가장 빠른 것은 내가의 부울 배열을하고 구현할 수있었습니다 위치.

pos = 0 
    truth = container[RATIO,:] > threshold 

    while pos < len(truth): 
     start = numpy.argmax(truth[pos:]) + pos + offset 
     end = numpy.argmin(truth[start:]) + start + offset 
     if not truth[start]:#nothing more 
      break 
     if start == end:#goes to the end 
      end = len(truth) 
     pos = end 

는하지만 내 배열에서 위치의 수십억 내가 찾는거야 스팬 그냥 보통 행에 몇 위치 있다는 사실 너무 느리다. 누구든지이 기간을 찾는 더 빠른 방법을 알고 있습니까?

답변

5

어떻게하면 좋을까요? 먼저 부울 배열을, 당신은 :

In [11]: a 
Out[11]: array([0, 0, 0, 2, 2, 0, 2, 2, 2, 0]) 

In [12]: a1 = a > 1 

시프트 그것을 roll 사용하여 왼쪽에 하나 (각 인덱스의 다음 상태를 얻을) :이 non-zero입니다

In [13]: a1_rshifted = np.roll(a1, 1) 

In [14]: starts = a1 & ~a1_rshifted # it's True but the previous isn't 

In [15]: ends = ~a1 & a1_rshifted 

각의 시작 진정한 배치 (또는 각각 최종 배치) :

In [16]: np.nonzero(starts)[0], np.nonzero(ends)[0] 
Out[16]: (array([3, 6]), array([5, 9])) 

그리고이 함께 완봉 :

In [17]: zip(np.nonzero(starts)[0], np.nonzero(ends)[0]) 
Out[17]: [(3, 5), (6, 9)] 
+0

이것은 답으로 보입니다. 7 백만 포인트 이상의 실행은 원래 코드의 50 분에 비해 약 20 초 걸렸습니다. 감사! – ACEnglish

+1

'roll'이 사용 중이기 때문에 가장자리 케이스가 올바르지 않습니다. 예를 들어 스팬의 시작 위치가 0 인 경우 마지막 항목이 'True'일 때만 스팬을 얻을 수 있습니다. – Cuadue

+0

@Cuadue 그건 사실이야, 나는이 시점을 바로 잡아야한다고 생각하는 막연한 기억을 가지고있다. 어떻게 든 롤 첫 번째와 마지막 (?) 요소를 업데이트해야한다 ...: –

1

당신은 scipy 라이브러리에 액세스하는 경우 : 당신은 비 제로 값의 어떤 지역을 식별하는 데 scipy.ndimage.measurements.label을 사용할 수 있습니다

합니다. 각 요소의 값이 원래 배열의 범위 또는 범위의 ID 인 배열을 반환합니다.

그런 다음 scipy.ndimage.measurements.find_objects을 사용하면 해당 범위를 추출하는 데 필요한 조각을 반환 할 수 있습니다. 해당 슬라이스에서 직접 시작/끝 값에 액세스 할 수 있습니다. 당신의 예에서

:

from numpy import array 
from scipy.ndimage.measurements import label, find_objects 

data = numpy.array([0, 0, 0, 2, 2, 0, 2, 2, 2, 0]) 

labels, number_of_regions = label(a) 
ranges = find_objects(labels) 

for identified_range in ranges: 
    print identified_range[0].start, identified_range[0].stop 

당신은 볼 수 :이 도움이

3 5 
6 9 

희망!

+0

성능에 대한 업데이트 만 - 가치있는 것을 위해! 이것은 내 컴퓨터에서 ~ 0.6 초 안에 7 백만 포인트를 수행합니다. – djspoulter

관련 문제