2017-12-06 1 views
0

TensorFlow에서 훈련 데이터 xs이 numpy NHCW 형식으로 있다고 가정 해보십시오. 내가 Tensorflow에 xs에서 배치를 샘플링 할 내가 tensor_list에서 대신 샘플링의tf.train.batch가 TensorFlow의 출력에 여분의 치수를 추가하는 이유는 무엇입니까?

xs = np.reshape(range(32), [4,2,2,2])  
tensor_list = [tf.convert_to_tensors(x) for x in xs] 
#x_tensor = tf.convert_to_tensors(xs) # tried this version too 
x_batch = tf.train.shuffle_batch(tensor_list, batch_size=3, capacity=50, min_after_dequeue=10) 

을했다,이 코드는 길이가 (이 경우 4) 데이터 포인트의 수를 동일 목록을 반환하고, 각 list 요소는 첫 번째 차원이 batch_size (이 경우 3) 인 텐서입니다. 개인적으로 직관적 인 결과는 x_batch이 4 차원 텐서이고 첫 번째 차원의 값이 batch_size이고 내용이 무작위로 샘플링됩니다. 그런 다음 sess.run(x_batch) 번으로 전화 할 때마다 다른 배치가 있습니다.

내가 잘못 한 부분을 알려주십시오.

답변

0

알아 내지 못했지만 아래에서 다른 해결책을 찾았습니다.

xs = np.reshape(range(32), [4,2,2,2])  
x_tensor = tf.convert_to_tensor(xs) 
dataset = tf.data.Dataset.from_tensor_slices(x_tensor) 
dataset = dataset.batch(2) 
dataset = dataset.shuffle(buffer_size=10000) 
iterator = dataset.make_initializable_iterator() 
sess.run(iterator.initializer) 

while True: 
    try: 
    next_element = iterator.get_next() 
    except tf.errors.OutOfRangeError: 
    print("End of dataset") # ==> "End of dataset" 

이어서이 뜻 출력은 2 사이즈마다 iterator.get_next() 임의 배치가 호출된다.

관련 문제