2016-06-27 7 views
2

Tensorflow/python API를 사용하여 이미지를 포즈에 매핑하는 회귀 네트워크를 구현 중이며 FixedLengthRecordReader의 출력을 처리하려고합니다.슬라이스 Tensorflow FixedLengthRecordReader 값

나는 내 용도로 최소한 cifar10 example을 적용하려고합니다.

cifar10 예제는 원시 바이트를 읽고 해독 한 다음 분할합니다.

result.key, value = reader.read(filename_queue) 

# Convert from a string to a vector of uint8 that is record_bytes long. 
record_bytes = tf.decode_raw(value, tf.uint8) 

# The first bytes represent the label, which we convert from uint8->int32. 
result.label = tf.cast(
    tf.slice(record_bytes, [0], [label_bytes]), tf.int32) 

# The remaining bytes after the label represent the image, which we reshape 
# from [depth * height * width] to [depth, height, width]. 
depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), 
         [result.depth, result.height, result.width]) 
# Convert from [depth, height, width] to [height, width, depth]. 
result.uint8image = tf.transpose(depth_major, [1, 2, 0]) 

나는 (pose_data, image_data)로 저장 데이터와 바이너리 파일의 목록에서 읽고 있습니다. 내 포즈 데이터가 float32이고 이미지 데이터가 uint8이기 때문에 먼저 슬라이스하고 캐스팅해야합니다. 불행하게도, reader.read의 값 결과는 0 차원 문자열 텐서이므로 슬라이스가 작동하지 않습니다.

key, value = reader.read(filename_queue) 
print value.dtype 
print value.get_shape() 

<dtype: 'string'> 
() 

tf.decode_raw (값 DTYPE)의 결과는 1 차원 배열이지만, 지정하는 DTYPE 필요하며 tf.string이 걸리는 것을 유효 타입이다.

디코딩 전에 슬라이스 할 수 있습니까? 아니면 다시 해독해야합니까? -> case -> slice -> recast? 거기에 또 다른 방법이 있습니까?

답변

0
영업뿐만 아니라 데이터가 여러 유형 (OP의 질문에)있을 때 작동하지 않습니다 두 번 디코드 솔루션 언급

cifar10 예와의 요인으로 n_uint8_vals이 필요 '줄을 서지'않습니다 (보다 일반적인 경우).

데이터가 예를 들어 경우 :

[float32][int16][int16] 

두번 디코드 작동합니다. 그러나 경우 데이터는 다음 tf.decode_raw는 반 float32의 오프셋 (offset)에 동의하지 않는 한이 작동하지 않습니다

[int16][float32][int16] 

. 이 경우 작업을 수행 무엇

tf.substr(),

result.key, value = reader.read(filename_queue) 

실제로 문자열 (또는 bytestring 만약에 당신) 자체를 분할 할 수있다

의 리턴 값이다.

-1

솔루션을 찾았습니다 : 두 번 디코딩하고 반을 버리십시오. 그다지 효율적이지는 않습니다. (누군가 더 좋은 해결책을 가지고 있다면 기쁜 마음으로 듣겠습니다.)하지만 효과가있는 것 같습니다.

key, value = reader.read(filename_queue) 
uint8_bytes = tf.decode_raw(value, tf.uint8) 
uint8_data = uint8_bytes[:n_uint8_vals] 
float32_bytes = tf.decode_raw(value, tf.float32) 
float32_start_index = n_uint8_vals // 4 
float32_data = float32_bytes[float32_start_index:] 

이 4

+0

링크가 깨졌습니다. 여기에 붙여 넣기가 가능한 경우 예제를보고 싶습니다. – Bastiaan