2017-12-19 2 views
0

Tensorflow Dataset API를 사용하여 폴더 당 하나의 배치 (이미지가 들어있는 각 폴더)를 만들고 싶습니다.ListDirectory가있는 Tensorflow 데이터 세트 API

import tensorflow as tf 
import os 
import pdb 

def parse_file(filename): 
    image_string = tf.read_file(filename) 
    image_decoded = tf.image.decode_png(image_string) 
    image_resized = tf.image.resize_images(image_decoded, [48, 48]) 
    return image_resized #, label 

def parse_dir(frame_dir): 
    filenames = tf.gfile.ListDirectory(frame_dir) 
    batch = tf.constant(5) 
    batch = tf.map_fn(parse_file, filenames) 
    return batch 

directory = "../Detections/NAC20171125" 
# filenames = tf.constant([os.path.join(directory, f) for f in os.listdir(directory)]) 
frames = [os.path.join(directory, str(f)) for f in range(10)] 


dataset = tf.data.Dataset.from_tensor_slices((frames)) 
dataset = dataset.map(parse_dir) 

dataset = dataset.batch(256) 
iterator = dataset.make_initializable_iterator() 
next_element = iterator.get_next() 


with tf.Session() as sess: 
    sess.run(iterator.initializer) 
    while True: 
     try: 
      batch = sess.run(next_element) 
      print(batch.shape) 
     except tf.errors.OutOfRangeError: 
      break 

그러나, (parse_dir에서) tf.gfile.ListDirectory 대신 텐서의 정상적인 문자열을 기대 : 나는 다음과 같은 간단한 코드가 있습니다. 이제 오류는

TypeError: Expected binary or unicode string, got <tf.Tensor 'arg0:0' shape=() dtype=string> 

이 문제를 해결하는 간단한 방법이 있습니까? 여기

답변

2

문제는 tf.gfile.ListDirectory() 파이썬 문자열을 기대 파이썬 함수이며, parse_dir()frame_dir 인수가 tf.Tensor 것입니다. 따라서 디렉터리에있는 파일을 나열하려면 동등한 TensorFlow 작업이 필요하며 tf.data.Dataset.list_files() (기반은 tf.matching_files())은 아마도 가장 비슷한 것일 수 있습니다.

directory = "../Detections/NAC20171125" 
frames = [os.path.join(directory, str(f)) for f in range(10)] 

# Start with a dataset of directory names. 
dataset = tf.data.Dataset.from_tensor_slices(frames) 

# Maps each subdirectory to the list of files in that subdirectory and flattens 
# the result. 
dataset = dataset.flat_map(lambda dir: tf.data.Dataset.list_files(dir + "/*")) 

# Maps each filename to the parsed and resized image data. 
dataset = dataset.map(parse_file) 

dataset = dataset.batch(256) 

iterator = dataset.make_initializable_iterator() 
next_element = iterator.get_next()