2017-10-03 5 views
2

MNIST 데이터의 자체 버전을 만들려고합니다. 교육 자료와 테스트 데이터를 다음 파일로 변환했습니다. (관심있는 사람들은 내가이 보인다 JPG-PNG-to-MNIST-NN-Format을 사용하여 가까운 내가 목표로하고있는 무슨에 저를 얻을 않았다하십시오.)MNIST 데이터 세트 만들기 (MNIST 포맷과 동일)

test-images-idx3-ubyte.gz 
test-labels-idx1-ubyte.gz 
train-images-idx3-ubyte.gz 
train-labels-idx1-ubyte.gz 

그러나이 파일 형식과의 형식과 아주 동일하지 않습니다 MNIST 데이터 (mnist.pkl.gz). 나는 pkl이 데이터가 절인되었다는 것을 의미하는 것으로 이해하지만 나는 데이터를 절식하는 과정을 실제로 이해하지 못한다. 절어에 특정한 순서가 있는가? 누군가 내 데이터를 피클 링해야하는 코드를 제공 할 수 있습니까?

답변

1
import gzip 
import os 

import numpy as np 
import six 
from six.moves.urllib import request 

parent = 'http://yann.lecun.com/exdb/mnist' 
train_images = 'train-images-idx3-ubyte.gz' 
train_labels = 'train-labels-idx1-ubyte.gz' 
test_images = 't10k-images-idx3-ubyte.gz' 
test_labels = 't10k-labels-idx1-ubyte.gz' 
num_train = 17010 
num_test = 3010 
dim = 32*32 


def load_mnist(images, labels, num): 
    data = np.zeros(num * dim, dtype=np.uint8).reshape((num, dim)) 
    target = np.zeros(num, dtype=np.uint8).reshape((num,)) 

    with gzip.open(images, 'rb') as f_images,\ 
      gzip.open(labels, 'rb') as f_labels: 
     f_images.read(16) 
     f_labels.read(8) 
     for i in six.moves.range(num): 
      target[i] = ord(f_labels.read(1)) 
      for j in six.moves.range(dim): 
       data[i, j] = ord(f_images.read(1)) 

    return data, target 


def download_mnist_data(): 

    print('Converting training data...') 
    data_train, target_train = load_mnist(train_images, train_labels, 
              num_train) 
    print('Done') 
    print('Converting test data...') 
    data_test, target_test = load_mnist(test_images, test_labels, num_test) 
    mnist = {} 
    mnist['data'] = np.append(data_train, data_test, axis=0) 
    mnist['target'] = np.append(target_train, target_test, axis=0) 

    print('Done') 
    print('Save output...') 
    with open('mnist.pkl', 'wb') as output: 
     six.moves.cPickle.dump(mnist, output, -1) 
    print('Done') 
    print('Convert completed') 


def load_mnist_data(): 
    if not os.path.exists('mnist.pkl'): 
     download_mnist_data() 
    with open('mnist.pkl', 'rb') as mnist_pickle: 
     mnist = six.moves.cPickle.load(mnist_pickle) 
    return mnist 
download_mnist_data() 
+0

이것은 훌륭합니다. - 왜 아무도 그것을 좋아하지 않았는지 확실하지 않음 – javadba

관련 문제