다른 Keras 네트워크 (B)에서 Keras 네트워크 (A)를 사용하려고합니다. 나는 네트워크 A를 먼저 훈련시킨다. 그런 다음 네트워크 B에서이를 사용하여 일부 정규화를 수행합니다. 내부 네트워크 B evaluate
또는 predict
을 사용하여 네트워크 A의 출력을 얻고 싶습니다. 불행히도이 함수는 숫자가 적은 배열을 기대하기 때문에 작동하지 않습니다. 대신 Tensorflow 변수를 입력으로받습니다. 여기keras forward pass (tensorflow 변수를 입력으로 사용)
내가 정의 regularizer 내부 네트워크 A를 사용하고 방법은 다음과 같습니다
class CustomRegularizer(Regularizer):
def __init__(self, model):
"""model is a keras network"""
self.model = model
def __call__(self, x):
"""Need to fix this part"""
return self.model.evaluate(x, x)
내가 입력으로 Tensorflow 변수와 Keras 네트워크와 전진 패스를 계산할 수있는 방법
?
x = np.ones((1, 64), dtype=np.float32)
model.predict(x)[:, :10]
출력 : Tensorflow와
array([[-0.0244251 , 3.31579041, 0.11801113, 0.02281714, -0.11048832,
0.13053198, 0.14661783, -0.08456061, -0.0247585 ,
0.02538805]], dtype=float32)
x = tf.Variable(np.ones((1, 64), dtype=np.float32))
model.predict_function([x])
출력 :
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-92-4ed9d86cd79d> in <module>()
1 x = tf.Variable(np.ones((1, 64), dtype=np.float32))
----> 2 model.predict_function([x])
~/miniconda/envs/bolt/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
2266 updated = session.run(self.outputs + [self.updates_op],
2267 feed_dict=feed_dict,
-> 2268 **self.session_kwargs)
2269 return updated[:len(self.outputs)]
2270
~/miniconda/envs/bolt/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
776 try:
777 result = self._run(None, fetches, feed_dict, options_ptr,
--> 778 run_metadata_ptr)
779 if run_metadata:
780 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~/miniconda/envs/bolt/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
952 np_val = subfeed_val.to_numpy_array()
953 else:
--> 954 np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
955
956 if (not is_tensor_handle_feed and
~/miniconda/envs/bolt/lib/python3.6/site-packages/numpy/core/numeric.py in asarray(a, dtype, order)
529
530 """
--> 531 return array(a, dtype, copy=False, order=order)
532
533
ValueError: setting an array element with a sequence.
,536,913,632 예를 들어
, 여기에 내가 NumPy와 함께 무엇을 얻을 10
네트워크에 내 질문에 어떻게 사용되는지에 대한 문맥을 추가했습니다. 나는 나의 문제를 해결하기 위해 당신의 답을 아직 적응할 수 없었다. –
죄송하지만 디버깅을 돕기 위해 더 많은 세부 사항이 필요하다고 생각합니다. 내가 생각할 수있는 유일한 방법은'cr ([sess.run (x)])'와'cr = CustomRegularizer (model)'를 시도 할 수 있다는 것이다. –