일부 심층 학습 알고리즘을 구현하기 위해 현재 nd4j 및 dl4j를 사용하고 있습니다. 그러나, 나는 datavec + dl4j를 처음부터 사용할 수 없습니다. 내가 b
하위 폴더 a
에있는 일부 그레이 스케일 28x28 이미지가, 이미지 폴더에서dl4j - 2 차원이 아닌 행렬의 행 수를 얻을 수 없습니다.
ImageConverter icv = new ImageConverter();
DataSetIterator dataSetIterator = icv.Convert();
log.info("Build model....");
int numEpochs = 10;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.iterations(1)
.learningRate(0.006)
.updater(Updater.NESTEROVS).momentum(0.9)
.regularization(true).l2(1e-4)
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
.nIn(28 * 28)
.stride(1, 1)
.nOut(20)
.activation("identity")
.build())
.layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(24 * 24)
.nOut(2)
.activation("softmax")
.build())
.pretrain(false)
.backprop(true)
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(1));
log.info("Train model....");
for(int i=0; i<numEpochs; i++){
model.fit(dataSetIterator);
}
: 여기
public class ImageConverter {
private static Logger log = LoggerFactory.getLogger(ImageConverter.class);
public DataSetIterator Convert() throws IOException, InterruptedException {
log.info("Start to convert images...");
File parentDir = new File(System.getProperty("user.dir"), "src/main/resources/images/");
ParentPathLabelGenerator parentPathLabelGenerator = new ParentPathLabelGenerator();
ImageRecordReader recordReader = new ImageRecordReader(28,28,1,parentPathLabelGenerator);
FileSplit fs = new FileSplit(parentDir);
InputSplit[] filesInDirSplit = fs.sample(null, 100);
recordReader.initialize(filesInDirSplit[0]);
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, 2, 1, 2);
log.info("Image convert finished.");
return dataIter;
}
}
메인 클래스입니다 : 여기
내 이미지 변환기입니다 각기.
그러나 Exception in thread "main" java.lang.IllegalStateException: Unable to get number of of rows for a non 2d matrix
이 발생합니다.
[[[...],
...
]]
=================OUTPUT==================
[[1.00, 0.00],
[1.00, 0.00]]
또한 dataSetIterator.next().get(0).toString()
의 출력이 예에서
[[[[...],
...
]]]
=================OUTPUT==================
[1.00, 0.00]
그리고 mnisterIterator위한
,mnisterIterator.next().toString()
같은 것을되어야이다
dataSetIterator.next().toString()
하여 데이터를 찾고, 그 일 같다 :
[[...]...]
=================OUTPUT==================
[[0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
...]
dataSetIterator
이 잘못된 형식의 데이터를 반환했다고 추측합니다.
아무도 모르게 해결할 수 있습니까?