2017-05-01 1 views
0

많은 비슷한 질문을 읽고 올바르게 작동하지 않습니다.이전 세션의 int 변수를 tensorflow 1.1에서로드 할 수 없습니다.

필자의 모델은 잘 훈련되어 있으며 검사 점 파일은 모든 신기원으로 작성됩니다. 프로그램을 다시로드 한 후에는 epoch x에서 계속 진행할 수 있고 모든 반복마다 해당 신기원에서 인쇄 할 수 있습니다. 체크 포인트 파일 외부로 데이터를 간단하게 저장할 수는 있었지만 다른 모든 것들도 적절하게 저장된다는 확신을주기 위해이 작업을 수행하고 싶었습니다.

다시 시작할 때 epoch/global_step 변수의 값은 항상 0입니다.

import tensorflow as tf 
import numpy as np 
import tensorflow as tf 
import numpy as np 
# more imports 


def extract_number(f): # used to get latest checkpint file 
    s = re.findall("epoch(\d+).ckpt",f) 
    return (int(s[0]) if s else -1,f) 

def restore(init_op, sess, saver): # called to restore or just initialise model 
    list = glob(os.path.join("./params/e*")) 

    if list: 

     file = max(list,key=extract_number) 

     saver.restore(sess, file[:-5]) 


    sess.run(init_op) 
    return 


with tf.Graph().as_default() as g: 

    # build models 


    total_batch = data.train.num_examples/batch_size 

    epochLimit = 51 

    saver = tf.train.Saver() 

    init_op = tf.global_variables_initializer() 


    with tf.Session() as sess: 


     saver = tf.train.Saver() 

     init_op = tf.global_variables_initializer() 

     restore(init_op, sess, saver) 


     epoch = global_step.eval() 


     while epoch < epochLimit: 

      total_batch = data.train.num_examples/batch_size 

      for i in range(int(total_batch)): 

       sys.stdout.flush() 

       voxels = newData.eval() 

       batch_z = np.random.uniform(-1, 1, [batch_size, z_size]).astype(np.float32) 

       sess.run(opt_G, feed_dict={z:batch_z, train:True}) 
       sess.run(opt_D, feed_dict={input:voxels, z:batch_z, train:True}) 


       with open("out/loss.csv", 'a') as f: 
        batch_loss_G = sess.run(loss_G, feed_dict={z:batch_z, train:False}) 
        batch_loss_D = sess.run(loss_D, feed_dict={input:voxels, z:batch_z, train:False}) 
        msgOut = "Epoch: [{0}], i: [{1}], G_Loss[{2:.8f}], D_Loss[{3:.8f}]".format(epoch, i, batch_loss_G, batch_loss_D) 

        print(msgOut) 

      epoch=epoch+1 
      sess.run(global_step.assign(epoch)) 
      saver.save(sess, "params/epoch{0}.ckpt".format(epoch)) 

      batch_z = np.random.uniform(-1, 1, [batch_size, z_size]).astype(np.float32) 
      voxels = sess.run(x_, feed_dict={z:batch_z}) 

      v = voxels[0].reshape([32, 32, 32]) > 0 
      util.save_binvox(v, "out/epoch{0}.vox".format(epoch), 32) 

또한 하단의 assign을 사용하여 전역 단계 변수를 업데이트합니다. 어떤 아이디어? 어떤 도움이라도 대단히 감사하겠습니다.

답변

0

많은 것들을 시도했기 때문에 원본 코드가 여러 가지 이유로 잘못되었습니다. 첫 번째 응답자 Alexandre Passos가 유효한 포인트를 제공하지만, 게임을 변경 한 이유는 스코프 (아마도?)의 사용이기도합니다.

import tensorflow as tf 
import numpy as np 
# more imports 


def extract_number(f): # used to get latest checkpint file 
    s = re.findall("epoch(\d+).ckpt",f) 
    return (int(s[0]) if s else -1,f) 

def restore(sess, saver): # called to restore or just initialise model 


    list = glob(os.path.join("./params/e*")) 

    if list: 

     file = max(list,key=extract_number) 

     saver.restore(sess, file[:-5]) 
     return saver, True, sess 

    saver = tf.train.Saver() 
    init_op = tf.global_variables_initializer() 
    sess.run(init_op) 

    return saver, False , sess 


batch_size = 100 
learning_rate = 0.0001 
beta1 = 0.5 
z_size = 100 
save_interval = 1 

data = dataset.read() 

total_batch = data.train.num_examples/batch_size 

def fill_queue(): 
    for i in range(int(total_batch*epochLimit)): 
     sess.run(enqueue_op, feed_dict={batch: data.train.next_batch(batch_size)}) # runnig in seperate thread to feed a FIFOqueue 



with tf.variable_scope("glob"): 
    global_step = tf.get_variable(name='global_step', initializer=0,trainable=False) 

# build models 

epochLimit = 51 

saver = tf.train.Saver() 

with tf.Session() as sess: 

    saver,rstr,sess = restore(sess, saver) 



    with tf.variable_scope("glob", reuse=True): 
     epocht = tf.get_variable(name='global_step', trainable=False, dtype=tf.int32) 

    epoch = epocht.eval() 


    while epoch < epochLimit: 

     total_batch = data.train.num_examples/batch_size 

     for i in range(int(total_batch)): 

      sys.stdout.flush() 

      voxels = newData.eval() 

      batch_z = np.random.uniform(-1, 1, [batch_size, z_size]).astype(np.float32) 

      sess.run(opt_G, feed_dict={z:batch_z, train:True}) 
      sess.run(opt_D, feed_dict={input:voxels, z:batch_z, train:True}) 


      with open("out/loss.csv", 'a') as f: 
       batch_loss_G = sess.run(loss_G, feed_dict={z:batch_z, train:False}) 
       batch_loss_D = sess.run(loss_D, feed_dict={input:voxels, z:batch_z, train:False}) 
       msgOut = "Epoch: [{0}], i: [{1}], G_Loss[{2:.8f}], D_Loss[{3:.8f}]".format(epoch, i, batch_loss_G, batch_loss_D) 

       print(msgOut) 

     epoch=epoch+1 
     sess.run(global_step.assign(epoch)) 
     saver.save(sess, "params/epoch{0}.ckpt".format(epoch)) 

     batch_z = np.random.uniform(-1, 1, [batch_size, z_size]).astype(np.float32) 
     voxels = sess.run(x_, feed_dict={z:batch_z}) 

     v = voxels[0].reshape([32, 32, 32]) > 0 
     util.save_binvox(v, "out/epoch{0}.vox".format(epoch), 32) 
: 그것은 사람을 도움이된다면 다음

는 작업 업데이트 코드
1

복원 후 sess.run(init_op)을 호출하면 모든 변수가 초기 값으로 재설정됩니다. 그 줄을 주석 처리하고 모든 것이 작동해야합니다.

관련 문제