2016-11-28 2 views
1

나는 Tensorflow가 그래프를 관리하는 것처럼 보이는 펑키 한 일이 있음을 깨달았습니다.Tensorflow는 어떻게 그래프를 관리합니까?

모델을 빌드 (및 다시 빌드)하기가 번거롭기 때문에 클래스에 내 사용자 정의 모델을 랩핑하기로 결정하여 다른 곳에서 쉽게 다시 인스턴스화 할 수있었습니다.

(원래 장소에서) 코드를 연습하고 테스트했을 때 그래프의 변수를로드 한 코드에서 변수 재 정의와 기타 모든 종류의 이상한 오류가 발생했습니다. 이것은 (비슷한 질문에 대한 마지막 질문에서) 모든 것이 두 번 호출된다는 암시였습니다.

추적 기능을 수행 한 후로드 된 코드를 사용하는 방식으로 변경되었습니다. 그것은 너무

class MyModelUser(object): 
    def forecast(self): 
     # .. build the model in the same way as in the training code 
     # load the model checkpoint 
     # call the "predict" function on the model 
     # manipulate the prediction and return it 

같은 구조를 가지고 그리고 몇 가지 코드에 내가

def test_the_model(self): 
    model_user = MyModelUser() 
    print(model_user.forecast()) # 1 
    print(model_user.forecast()) # 2 

을했고 (분명히) 나는이 개 예측이를 볼 것으로 예상 MyModelUser 사용하는 클래스 내에서 사용 된 라고 불렀다. 대신, 최초의 예측이라고 예상대로 작동했지만 두 번째 호출이 변수 재사용의 TON에 ValueError 이들 중 하나의 예를 던졌다되었다이었다

ValueError: Variable weight_def/weights already exists, disallowed. Did you mean to set reuse=True in VarScope? 

나는이 시리즈를 추가하여 오류를 진압하기 위해 관리 변수를 만들기 위해 get_variable을 사용한 try/except 블록 중 하나를 선택하고 범위에서 reuse_variables을 호출 한 다음 예외를 제외하고는 get_variable을 호출합니다. 내가 말한 변덕에

tensorflow.python.framework.errors.NotFoundError: Tensor name "weight_def/weights/Adam_1" not found in checkpoint files 

이 있었다 중 하나가 불쾌한 오류의 새로운 세트에 가져 "나는 __init__로 모델링 건물 코드를 이동하는 경우 그래서 그 한 번만 내장 무엇?"

내 새로운 모델 사용자 : 지금

class MyModelUser(object): 
    def __init__(self): 
     # ... build the model in the same way as in the training code 
     # load the model checkpoint 


    def forecast(self): 
     # call the "predict" function on the model 
     # manipulate the prediction and return it 

및 예상대로

def test_the_model(self): 
    model_user = MyModelUser() 
    print(model_user.forecast()) # 1 
    print(model_user.forecast()) # 2 

작품, 오류없이 두 개의 예측을 인쇄. 이것은 내가 변수 재사용을 제거 할 수 있다고 믿게한다.

내 질문은 :

왜 수정 했습니까? 이론적으로 그래프는 원래의 예측 방법에서 매번 재구성되어야하므로 하나 이상의 그래프를 생성해서는 안됩니다. 함수가 완료된 후에도 Tensorflow가 그래프를 유지합니까? 왜 창조 코드를 __init__으로 옮기는 것이 효과가 있었습니까? 이로 인해 절망적으로 혼란스러워졌습니다.

답변

2

기본적으로 TensorFlow는 TensorFlow API를 처음 호출 할 때 생성되는 단일 글로벌 tf.Graph 인스턴스를 사용합니다.tf.Graph을 명시 적으로 만들지 않으면 모든 작업, 텐서 및 변수가 해당 기본 인스턴스에 만들어집니다. 즉, 코드에서 model_user.forecast()으로 호출 할 때마다 동일한 전역 그래프에 작업이 추가되므로 작업이 다소 낭비됩니다.

가 여기에 행동의 두 가지 코스 (적어도) :

  • 이상적인 작업을 MyModelUser.__init__()는 예측을 수행하는 데 필요한 모든 작업에 전체 tf.Graph를 구축하도록 코드를 재구성하는 것, MyModelUser.forecast()은 기존 그래프에서 sess.run() 호출을 간단하게 수행합니다. 이상적으로는 tf.Session도 하나만 만들면됩니다. TensorFlow가 세션의 그래프에 대한 정보를 캐시하고 실행이 더 효율적이기 때문에 이상적입니다.

  • — 덜 침습적이지만 — 변화가 MyModelUser.forecast()에 대한 모든 호출에 대해 tf.Graph 새로운를 생성하는 것 아마 덜 효율적. 그것은 훨씬 상태가 MyModelUser.__init__() 방식으로 만드는 방법 질문에서 불분명하지만 다른 그래프에서 두 통화를 넣어 다음과 같이 뭔가를 할 수 :

    def test_the_model(self): 
        with tf.Graph(): # Create a local graph 
        model_user_1 = MyModelUser() 
        print(model_user_1.forecast()) 
        with tf.Graph(): # Create another local graph 
        model_user_2 = MyModelUser() 
        print(model_user_2.forecast()) 
    
0

TF에는 새로운 작업 등이 추가되는 기본 그래프가 있습니다. 함수를 두 번 호출하면 동일한 그래프에 동일한 작업을 두 번 추가하게됩니다. 따라서 그래프를 한 번 작성하고 여러 번 평가하십시오 (이전과 마찬가지로 "정상적인"방법이기도합니다). 또는 물건을 변경하려면 reset_default_graph https://www.tensorflow.org/versions/r0.11/api_docs/python/framework.html#reset_default_graph을 사용하여 그래프를 재설정하면됩니다. 신선한 상태.

관련 문제