- Published on
Training of convolutional neural network for classification of handwritten digits
- Authors
- tesar-tech
This post is a simple example of CNN training. It uses the uncomplicated layer structure and training options for keeping things as simple as possible. The structure is simple enough to be trained fast on common CPU. Accuracy of trained network is over 98 %. Another post describes how to use this network to classify your own handwritten digit.
%% load data
%path to build-in digits dataset (part of MNIST dataset)
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', 'nndemos','nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
%Split data to train and test parts
[imdsTrain, imdsValidation] = splitEachLabel (imds, 0.7, 'randomize');
%% constructing cnn layer by layer
layers = [
imageInputLayer([28 28 1]) % input layer (grayscale image with size of 28x28 pixels)
convolution2dLayer(5,16,'Padding', 'same') % 16 convolution filters with size of 5
batchNormalizationLayer % normalization layer
reluLayer %relu for additional non-linearity (input lower than 0 is changed to 0, otherwise it still unchanged)
fullyConnectedLayer(10) % 10 - number of neurons
softmaxLayer % normalization
classificationLayer
];
%% specify options
options = trainingOptions('sgdm',...% type of solver
'Verbose', false,...% dont output in command window
'Plots', 'training-progress',...% plot nice graph with training progress
'MaxEpoch',5,...%use every training image 5 times
'ValidationData', imdsValidation); % use validation data
%% train network
net = trainNetwork (imdsTrain, layers, options); % training of CNN
Similar posts