
Ordering in TensorFlow queues

I am trying to understand queues in more detail. With the code below I expect the output collection to come back in alphabetical order, since the producer is not shuffled, but that only appears to hold for every epoch except the first. Am I misunderstanding something?

from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 

import time 
import tensorflow as tf 
import numpy as np 
import string 


# Basic model parameters as external flags. 
flags = tf.app.flags 
FLAGS = flags.FLAGS 
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') 
flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.') 
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.') 
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.') 
flags.DEFINE_integer('batch_size', 100, 'Batch size. ' 
        'Must divide evenly into the dataset sizes.') 
flags.DEFINE_string('train_dir', '/tmp/data', 
        'Directory to put the training data.') 
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data ' 
        'for unit testing.') 


def run_training():
    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        with tf.name_scope('input'):
            # Input data
            images_initializer = tf.placeholder(
                dtype=tf.int64,
                shape=[52, 1])
            input_images = tf.Variable(
                images_initializer, trainable=False, collections=[])

            image = tf.train.slice_input_producer(
                [input_images], num_epochs=2)
            images = tf.train.batch(
                [image], batch_size=1)

            alph_initializer = tf.placeholder(
                dtype=tf.string,
                shape=[26, 1])
            input_alph = tf.Variable(
                alph_initializer, trainable=False, collections=[])

            alph = tf.train.slice_input_producer(
                [input_alph], shuffle=False, capacity=26)
            alphs = tf.train.batch(
                [alph], batch_size=1)

        my_list = np.array(list(range(0, 52))).reshape(52, 1)
        my_list_val = np.array(list(string.ascii_lowercase)).reshape(26, 1)

        # Create the op for initializing variables.
        init_op = tf.initialize_all_variables()

        # Create a session for running Ops on the Graph.
        sess = tf.Session()

        # Run the Op to initialize the variables.
        sess.run(init_op)
        sess.run(input_images.initializer,
                 feed_dict={images_initializer: my_list})
        sess.run(input_alph.initializer,
                 feed_dict={alph_initializer: my_list_val})

        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # And then after everything is built, start the training loop.
        collection = []
        try:
            step = 0
            while not coord.should_stop():
                start_time = time.time()

                # Run one step of the model.
                integer = sess.run(image)
                #print("Integer val", integer)

                char = sess.run(alph)
                collection.append(char[0][0])
                print("String val", char)

                duration = time.time() - start_time

        except tf.errors.OutOfRangeError:
            print('Saving')
            print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
        finally:
            # When done, ask the threads to stop.
            coord.request_stop()
        print(str(collection))

        # Wait for threads to finish.
        coord.join(threads)
        sess.close()


def main(_): 
    run_training() 


if __name__ == '__main__': 
    tf.app.run() 

Answer

To clear up my confusion: the fix was to run the batched tensors, images and alphs, instead of the per-slice tensors image and alph:

try:
    step = 0
    while not coord.should_stop():
        start_time = time.time()

        # Run one step of the model.
        integer = sess.run(images)
        #print("Integer val", integer)

        char = sess.run(alphs)
        collection.append(char[0][0])
        print("String val", char)

        duration = time.time() - start_time

except tf.errors.OutOfRangeError:
    print('Saving')
    print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
finally:
    # When done, ask the threads to stop.
    coord.request_stop()
print(str(collection))
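
The likely reason the original loop looked out of order: sess.run(alph) re-executes the producer's dequeue op in the main thread, while the queue runner created by tf.train.batch is dequeuing from the same producer queue in a background thread, so the two consumers race and the main thread only sees some of the elements. Running alphs (and images) instead dequeues from the batch queue, leaving its queue runner as the sole consumer of the producer queue, so the original order is preserved. Below is a minimal, self-contained sketch of that pattern; it uses the same TensorFlow 1.x queue-runner APIs as the code above, and the lowercase-letter data is just an illustrative stand-in, not part of the original question:

import string

import tensorflow as tf

letters = list(string.ascii_lowercase)

# One unshuffled pass over the letters; slice_input_producer feeds a FIFO queue.
letter = tf.train.slice_input_producer(
    [tf.constant(letters)], shuffle=False, num_epochs=1)
# tf.train.batch adds a queue runner that is the only consumer of that queue.
letter_batch = tf.train.batch(letter, batch_size=1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # num_epochs is tracked in a local variable, so this is required too.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            # Dequeue from the batch queue only; prints a, b, c, ... in order.
            print(sess.run(letter_batch))
    except tf.errors.OutOfRangeError:
        pass
    finally:
        coord.request_stop()
    coord.join(threads)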