2017-12-07 1 views
0

안녕하세요, 임씨는 MNIST 및 softmax로 Tensorflow의 초보자 안내서를 수정하려고합니다. 이 자습서에는 10 개의 자릿수가 있습니다 (0-9 자릿수). 다른 데이터 세트 (EMNIST)를 사용하면 숫자와 문자에 대해 62 개의 클래스가 있습니다. 나는 orginal 한 예제의 모델이 무엇 은 : 28x28 이미지와 (10)의 총 픽셀 784 개 스탠드 클래스의 수는MNIST의 클래스 수를 변경하십시오. Tensorflow

x = tf.placeholder(tf.float32, [None, 784]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
y = tf.matmul(x, W) + b` 

입니다. 내가 원하는 건 :

x = tf.placeholder(tf.float32, [None, 784]) 
W = tf.Variable(tf.zeros([784, 62])) 
b = tf.Variable(tf.zeros([62])) 
y = tf.matmul(x, W) + b` 

62 클래스입니다. 하지만 다음 배치가 실행을 위해 호출되는 코드의이 부분에 도달하면 ... 역 추적 (마지막으로 가장 최근에 호출) 나는이 오류가

for _ in range(1000): 
batch_xs, batch_ys = mnist.train.next_batch(100) 
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 

을 :

File "calligraphy.py", line 77, in <module> 
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 
    File "C:\Users\Willy Barales\Anaconda3\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run 
    _sys.exit(main(_sys.argv[:1] + flags_passthrough)) 
    File "calligraphy.py", line 64, in main 
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 
    File "C:\Users\Willy Barales\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 789, in run 
    run_metadata_ptr) 
    File "C:\Users\Willy Barales\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 975, in _run 
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape()))) 
ValueError: Cannot feed value of shape (100, 10) for Tensor 'Placeholder_1:0', which has shape '(?, 62)' 

이 예제의 데이터 세트를 변경하는 방법에 대한 아이디어가 있으십니까? .next_batch()가 구현 된 mnist.py 파일에서 뭔가를 변경해야합니까?

내가 아는 한 EMNIST는 MNIST와 완전히 동일한 형식을 가지고 있습니다. 미리 감사드립니다. 새로운 데이터 세트에

정보 : http://biometrics.nist.gov/cs_links/EMNIST/Readme.txt

+0

예, mnist.train.next_batch (100)는 (100, 10)의 batch_ys 크기를 반환하며 (100, 62)가되도록 기대하고 있습니다. – Nejla

+0

고마워요! 내가해야 할 일은 mnist.py 파일에서 label로부터 하나의 핫 벡터가 생성 된 부분을 편집하는 것인데, 이는 batch_ys에 해당하는 것이기 때문이다. 'def extract_labels (f, one_hot = False, num_classes = 62)' –

답변

0

사람들은 batch_ys에 해당하는 사람이기 때문에 내가의 mnist.py 파일의 핫 벡터는 라벨에서 생성 된 부분을 편집하는 것이었다해야 할 모든 Neijla의 계몽 덕분입니다.

def extract_labels(f, one_hot=False, num_classes=62) 

물론 내 질문에 명시된대로 모델의 클래스 수를 변경하십시오.

관련 문제