2016-06-25 4 views
4

Tensorflow에서 사전 가중치를 사용하여 LSTM 모델을 구현하고 싶습니다. 이 무게는 Caffee 또는 Torch에서 올 수 있습니다.
파일 rnn_cell.py에 LSTM 셀이있는 것을 확인했습니다 (예 : rnn_cell.BasicLSTMCellrnn_cell.MultiRNNCell). 그러나이 LSTM 셀에 대해 사전에 가중치를 어떻게로드 할 수 있습니까?Tensorflow에서 미리 훈련 된 LSTM 모델 가중치를로드하는 방법

답변

0

미리 훈련 된 Caffe 모델을로드하기위한 솔루션입니다. this thread의 토론에서 참조한 full code here을 참조하십시오.

net_caffe = caffe.Net(prototxt, caffemodel, caffe.TEST) 
caffe_layers = {} 

for i, layer in enumerate(net_caffe.layers): 
    layer_name = net_caffe._layer_names[i] 
    caffe_layers[layer_name] = layer 

def caffe_weights(layer_name): 
    layer = caffe_layers[layer_name] 
    return layer.blobs[0].data 

def caffe_bias(layer_name): 
    layer = caffe_layers[layer_name] 
    return layer.blobs[1].data 

#tensorflow uses [filter_height, filter_width, in_channels, out_channels] 2-3-1-0 
#caffe uses [out_channels, in_channels, filter_height, filter_width] 0-1-2-3 
def caffe2tf_filter(name): 
    f = caffe_weights(name) 
    return f.transpose((2, 3, 1, 0)) 

class ModelFromCaffe(): 
    def get_conv_filter(self, name): 
     w = caffe2tf_filter(name) 
     return tf.constant(w, dtype=tf.float32, name="filter") 

    def get_bias(self, name): 
     b = caffe_bias(name) 
     return tf.constant(b, dtype=tf.float32, name="bias") 

    def get_fc_weight(self, name): 
     cw = caffe_weights(name) 
     if name == "fc6": 
      assert cw.shape == (4096, 25088) 
      cw = cw.reshape((4096, 512, 7, 7)) 
      cw = cw.transpose((2, 3, 1, 0)) 
      cw = cw.reshape(25088, 4096) 
     else: 
      cw = cw.transpose((1, 0)) 

     return tf.constant(cw, dtype=tf.float32, name="weight") 

images = tf.placeholder("float", [None, 224, 224, 3], name="images") 
m = ModelFromCaffe() 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    batch = cat.reshape((1, 224, 224, 3)) 
    out = sess.run([m.prob, m.relu1_1, m.pool5, m.fc6], feed_dict={ images: batch }) 
... 
+1

답변 해 주셔서 감사합니다. 그것은 나를 많이 돕는다. 그러나 RNN의 경우, 사전 훈련 된 가중치를 초기화하는 방법을 찾지 못했습니다. –

+0

ModelFromCaffe 클래스를 사용하여 변수를 만들 수 있습니다 (예 : 'fc6_W = tf.Variable (m.get_fc_weight ("fc6"), name = "fc6_W")'[여기에있는 문서를 참조하십시오.] (https://www.tensorflow.org/versions/r0.9/how_tos/variables/ index.html). – ssjadon

관련 문제