2017-04-05 1 views
4

반복자에서 채워지는 큐를 만들고 싶습니다. 다음 MWE에서는 그러나, 항상 같은 값이 대기열에 :python iterator에서 채우기 큐

import tensorflow as tf 
import numpy as np 

# data 
imgs = [np.random.randn(i,i) for i in [2,3,4,5]] 

# iterate through data infinitly 
def data_iterator(): 
    while True: 
     for img in imgs: 
      yield img 

it = data_iterator() 

# create queue for data 
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64]) 

# feed next element from iterator 
enqueue_op = q.enqueue(list(next(it))) 

# setup queue runner 
numberOfThreads = 1 
qr = tf.train.QueueRunner(q, [enqueue_op] * numberOfThreads) 
tf.train.add_queue_runner(qr) 

# dequeue 
dequeue_op = q.dequeue() 
dequeue_op = tf.Print(dequeue_op, data=[dequeue_op], message="dequeue()") 

# We start the session as usual ... 
with tf.Session() as sess: 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(10): 
     data = sess.run(dequeue_op) 
     print(data) 
. 
    coord.request_stop() 
    coord.join(threads) 

나는 반드시 feed_dict를 사용해야합니까? 그렇다면 어떻게 QueueRunner와 함께 사용해야합니까?

답변

3

(다음 (가)) 정확히 한 시간 목록을 실행합니다

enqueue_op = q.enqueue(list(next(it))) 

tensorflow를 실행. 그 후이 첫 번째 목록을 저장하고 enqueue_op을 실행할 때마다 q에 추가합니다. 이를 피하려면 자리 표시자를 사용해야합니다. Feeding placeholder는 tf.train.QueueRunner과 호환되지 않습니다. 대신 다음을 사용하십시오 :

import tensorflow as tf 
import numpy as np 
import threading 

# data 
imgs = [np.random.randn(i,i) for i in [2,3,4,5]] 

# iterate through data infinitly 
def data_iterator(): 
    while True: 
     for img in imgs: 
      yield img 

it = data_iterator() 

# create queue for data 
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64]) 

# feed next element from iterator 

img_p = tf.placeholder(tf.float64, [None, None]) 
enqueue_op = q.enqueue(img_p) 

dequeue_op = q.dequeue() 


with tf.Session() as sess: 
    coord = tf.train.Coordinator() 

    def enqueue_thread(): 
     with coord.stop_on_exception(): 
      while not coord.should_stop(): 
       sess.run(enqueue_op, feed_dict={img_p: list(next(it))}) 

    numberOfThreads = 1 
    for i in range(numberOfThreads): 
     threading.Thread(target=enqueue_thread).start() 



    for i in range(3): 
     data = sess.run(dequeue_op) 
     print(data)