2016-06-03 3 views

답변

2

대부분의 TensorFlow 텐서 (tf.Tensor 개체)는 이 불변이며이므로 값을 할당 할 수 없습니다. 그러나 텐서를 tf.Variable으로 작성한 경우 Variable.assign()을 호출하여 텐서를 값에 할당 할 수 있습니다.

코드 tf.Variable (tf.trainable_variables() 목록에서)을 문자열 이름으로 불필요하게 변환하는 코드입니다. 대신, 다음을 수행 할 수 있습니다


# Get the 0th trainable variable. 
var = tf.trainable_variables()[0] 

# Create an op to assign a new value. 
assign_op = var.assign([[1.0, 2.0], [3.0, 4.0]]) 

# Actually run the assignment. 
sess.run(assign_op) 
그러나
, 당신의 의견에 따라, 여러 그래프를 (즉, self.graph 및 가 다른), 그래서 일반적인 솔루션은 내가 위에서 썼다는하지 않습니다 작업. 이 경우 두 가지 옵션이 있습니다

  1. 다른 그래프의 이름으로 변수를 가져옵니다 (NBgraph_2.get_collection('trainable_variables') 채워 된 경우에만 동작을, 당신이 구축 tf.import_graph_def()를 사용하는 경우가 작동하지 않습니다 그래프) :

    var_name = graph_1.get_collection('trainable_variables')[0].name 
    
    var_in_g2 = [v for v in graph_2.get_collection('trainable_variables') 
          if v.name == var_name][0] 
    
    assign_op = var_in_g2.assign([[1.0, 2.0], [3.0, 4.0]]) 
    sess.run(assign_op) 
    
  2. 다른 그래프의 이름으로 텐서 가져 오기를 사용 tf.assign() :

    var_name = graph_1.get_collection('trainable_variables')[0].name 
    
    var_in_g2 = graph_2.get_tensor_by_name(var_name) 
    
    assign_op = tf.assign(var_in_g2, [[1.0, 2.0], [3.0, 4.0]]) 
    sess.run(assign_op) 
    
+0

감사합니다. 문제는 여러 개의 그래프가 있고'tf.trainable_variables()'가 비어 있다는 것입니다. 그것에 대한 해결 방법이 있습니까? – MBZ

+0

아, 'self.graph'와 'graph'가 다른 것이 오타인지 확실하지 않았습니다. 나는 귀하의 사건을 커버하기 위해 답을 업데이트했습니다. – mrry

+0

감사합니다. Derek :) – MBZ

관련 문제