2017-10-11 4 views
0

다른 Keras 네트워크 (B)에서 Keras 네트워크 (A)를 사용하려고합니다. 나는 네트워크 A를 먼저 훈련시킨다. 그런 다음 네트워크 B에서이를 사용하여 일부 정규화를 수행합니다. 내부 네트워크 B evaluate 또는 predict을 사용하여 네트워크 A의 출력을 얻고 싶습니다. 불행히도이 함수는 숫자가 적은 배열을 기대하기 때문에 작동하지 않습니다. 대신 Tensorflow 변수를 입력으로받습니다. 여기keras forward pass (tensorflow 변수를 입력으로 사용)

내가 정의 regularizer 내부 네트워크 A를 사용하고 방법은 다음과 같습니다

class CustomRegularizer(Regularizer): 
    def __init__(self, model): 
     """model is a keras network""" 
     self.model = model 

    def __call__(self, x): 
     """Need to fix this part""" 
     return self.model.evaluate(x, x) 
내가 입력으로 Tensorflow 변수와 Keras 네트워크와 전진 패스를 계산할 수있는 방법

?

x = np.ones((1, 64), dtype=np.float32) 
model.predict(x)[:, :10] 

출력 : Tensorflow와

array([[-0.0244251 , 3.31579041, 0.11801113, 0.02281714, -0.11048832, 
     0.13053198, 0.14661783, -0.08456061, -0.0247585 , 
0.02538805]], dtype=float32) 

x = tf.Variable(np.ones((1, 64), dtype=np.float32)) 
model.predict_function([x]) 

출력 :

--------------------------------------------------------------------------- 
ValueError        Traceback (most recent call last) 
<ipython-input-92-4ed9d86cd79d> in <module>() 
     1 x = tf.Variable(np.ones((1, 64), dtype=np.float32)) 
----> 2 model.predict_function([x]) 

~/miniconda/envs/bolt/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs) 
    2266   updated = session.run(self.outputs + [self.updates_op], 
    2267        feed_dict=feed_dict, 
-> 2268        **self.session_kwargs) 
    2269   return updated[:len(self.outputs)] 
    2270 

~/miniconda/envs/bolt/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata) 
    776  try: 
    777  result = self._run(None, fetches, feed_dict, options_ptr, 
--> 778       run_metadata_ptr) 
    779  if run_metadata: 
    780   proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 

~/miniconda/envs/bolt/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata) 
    952    np_val = subfeed_val.to_numpy_array() 
    953   else: 
--> 954    np_val = np.asarray(subfeed_val, dtype=subfeed_dtype) 
    955 
    956   if (not is_tensor_handle_feed and 

~/miniconda/envs/bolt/lib/python3.6/site-packages/numpy/core/numeric.py in asarray(a, dtype, order) 
    529 
    530  """ 
--> 531  return array(a, dtype, copy=False, order=order) 
    532 
    533 

ValueError: setting an array element with a sequence. 
,536,913,632 예를 들어

, 여기에 내가 NumPy와 함께 무엇을 얻을 10

답변

0

tensorflow 변수가오고 어디 있는지 잘 모르겠지만, 그것은이있는 경우, 당신은이 작업을 수행 할 수 있습니다 sess가 tensorflow 세션입니다

model.predict([sess.run(x)]) 

, 즉 sess = tf.Session().

+0

네트워크에 내 질문에 어떻게 사용되는지에 대한 문맥을 추가했습니다. 나는 나의 문제를 해결하기 위해 당신의 답을 아직 적응할 수 없었다. –

+0

죄송하지만 디버깅을 돕기 위해 더 많은 세부 사항이 필요하다고 생각합니다. 내가 생각할 수있는 유일한 방법은'cr ([sess.run (x)])'와'cr = CustomRegularizer (model)'를 시도 할 수 있다는 것이다. –

관련 문제