의 이름으로 텐서의 값을 할당 :나는 다음과 같은 코드를 사용하여 TensorFlow 그래프 안에 텐서를 잡는거야
names = [var.name for var in self.graph.get_collection('trainable_variables')]
tensor = graph.get_tensor_by_name(names[0])
sess.run(tensor)
어떻게 tensor
의 값을 설정할 수 있습니까?
의 이름으로 텐서의 값을 할당 :나는 다음과 같은 코드를 사용하여 TensorFlow 그래프 안에 텐서를 잡는거야
names = [var.name for var in self.graph.get_collection('trainable_variables')]
tensor = graph.get_tensor_by_name(names[0])
sess.run(tensor)
어떻게 tensor
의 값을 설정할 수 있습니까?
대부분의 TensorFlow 텐서 (tf.Tensor
개체)는 이 불변이며이므로 값을 할당 할 수 없습니다. 그러나 텐서를 tf.Variable
으로 작성한 경우 Variable.assign()
을 호출하여 텐서를 값에 할당 할 수 있습니다.
코드 tf.Variable
(tf.trainable_variables()
목록에서)을 문자열 이름으로 불필요하게 변환하는 코드입니다. 대신, 다음을 수행 할 수 있습니다
# Get the 0th trainable variable.
var = tf.trainable_variables()[0]
# Create an op to assign a new value.
assign_op = var.assign([[1.0, 2.0], [3.0, 4.0]])
# Actually run the assignment.
sess.run(assign_op)
그러나, 당신의 의견에 따라, 여러 그래프를 (즉,
self.graph
및 가 다른), 그래서 일반적인 솔루션은 내가 위에서 썼다는하지 않습니다 작업. 이 경우 두 가지 옵션이 있습니다
다른 그래프의 이름으로 변수를 가져옵니다 (NBgraph_2.get_collection('trainable_variables')
채워 된 경우에만 동작을, 당신이 구축 tf.import_graph_def()
를 사용하는 경우가 작동하지 않습니다 그래프) :
var_name = graph_1.get_collection('trainable_variables')[0].name
var_in_g2 = [v for v in graph_2.get_collection('trainable_variables')
if v.name == var_name][0]
assign_op = var_in_g2.assign([[1.0, 2.0], [3.0, 4.0]])
sess.run(assign_op)
다른 그래프의 이름으로 텐서 가져 오기를 사용 tf.assign()
:
var_name = graph_1.get_collection('trainable_variables')[0].name
var_in_g2 = graph_2.get_tensor_by_name(var_name)
assign_op = tf.assign(var_in_g2, [[1.0, 2.0], [3.0, 4.0]])
sess.run(assign_op)
감사합니다. 문제는 여러 개의 그래프가 있고'tf.trainable_variables()'가 비어 있다는 것입니다. 그것에 대한 해결 방법이 있습니까? – MBZ
아, 'self.graph'와 'graph'가 다른 것이 오타인지 확실하지 않았습니다. 나는 귀하의 사건을 커버하기 위해 답을 업데이트했습니다. – mrry
감사합니다. Derek :) – MBZ