Learning process

After selecting a CNN architecture, the learning process consists of proper initialization and tuning of the network hyperparameters. This task usually requires some experimentation, and doing it on Spark is no exception. The following parameters need to be set (a configuration sketch follows the list):

  • number of epochs, the number of full training cycles over the whole training data set,
  • updater, set to NESTEROVS with momentum 0.9 in all examples,
  • l1 and/or l2 regularization,
  • learning rate, adjusted after several tests to the fairly high starting value of 0.75,
  • learning rate decay factor, applied after a specified number of epochs with no accuracy improvement (see the schedule logic in the training method below),
  • number of iterations, set to 1 in all examples,
  • batch size, twelve times the number of available cores,
  • alpha parameter of the leaky ReLU activation function, set to 0.02 (the default is 0.01),
  • dropout value (only for the fully connected layer), set to 0.5.
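
To make the list concrete, here is a minimal sketch of how these parameters map onto a DL4J MultiLayerConfiguration. It uses the current DL4J API, so some builder names differ in the older release used in this article (which, for instance, also set iterations(1), a method newer releases removed); the seed and all layer sizes are illustrative assumptions, while width, height, channels and numLabels are the trainer fields used throughout this section:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(42)                                    // assumed, for reproducibility
        .updater(new Nesterovs(0.75, 0.9))           // learning rate 0.75, momentum 0.9
        .l2(1e-4)                                    // l2 regularization (value assumed)
        .activation(new ActivationLReLU(0.02))       // default activation: leaky ReLU, alpha = 0.02
        .list()
        .layer(new ConvolutionLayer.Builder(5, 5)    // layer sizes are assumptions
                .nOut(32).stride(1, 1).build())
        .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2).stride(2, 2).build())
        .layer(new DenseLayer.Builder()
                .nOut(500).dropOut(0.5).build())     // dropout only on the fully connected layer
        .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(numLabels).activation(Activation.SOFTMAX).build())
        .setInputType(InputType.convolutional(height, width, channels))
        .build();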

Some additional useful information about the effective training of neural networks and their evaluation can be found here and here.

The whole learning process is implemented in the following method:

public void train(JavaRDD<DataSet> train, JavaRDD<DataSet> test) {

    int batchSize = 12 * cores;   // batch size: twelve times the number of available cores
    int lrCount = 0;              // epochs since the last accuracy improvement
    double bestAccuracy = 0.0;    // accuracy is never negative; 0.0 is a clearer floor than Double.MIN_VALUE

    double learningRate = initialLearningRate;

    int trainCount = (int) train.count();
    log.info("Number of training images {}", trainCount);
    log.info("Number of test images {}", test.count());

    MultiLayerNetwork net = new MultiLayerNetwork(model.apply(learningRate, width, height, channels, numLabels));
    net.init();

    Map<Integer, Double> acc = new HashMap<>();
    for (int i = 0; i < epochs; i++) {

        // distribute one epoch of training over the Spark cluster
        SparkDl4jMultiLayer sparkNetwork = networkToSparkNetwork.apply(net);
        final MultiLayerNetwork nn = sparkNetwork.fitDataSet(train, batchSize, trainCount, cores);
        log.info("Epoch {} completed", i);

        // pair the predicted class with the true class for every test example
        JavaPairRDD<Object, Object> predictionsAndLabels = test.mapToPair(
                ds -> new Tuple2<>(label(nn.output(ds.getFeatureMatrix(), false)), label(ds.getLabels()))
                );
        MulticlassMetrics metrics = new MulticlassMetrics(predictionsAndLabels.rdd());
        double accuracy = 1.0 * predictionsAndLabels.filter(x -> x._1.equals(x._2)).count() / test.count();
        log.info("Epoch {} accuracy {} ", i, accuracy);
        acc.put(i, accuracy);
        predictionsAndLabels.take(10).forEach(t -> log.info("predicted {}, label {}", t._1, t._2));
        log.info("confusionMatrix {}", metrics.confusionMatrix());

        INDArray params = nn.params();
        if (accuracy > bestAccuracy) {
            // new best model: persist it, named after its accuracy
            bestAccuracy = accuracy;
            try {
                ModelSerializer.writeModel(nn, new File(workingDir, Double.toString(accuracy)), false);
            } catch (IOException e) {
                log.error("Error writing trained model", e);
            }
            lrCount = 0;
        } else {
            // no improvement: decay the learning rate every stepDecayThreshold epochs
            if (++lrCount % stepDecayThreshold == 0) {
                learningRate *= learningRateDecayFactor;
            }
            // after too many epochs without improvement, restart from the initial rate
            if (lrCount >= resetLearningRateThreshold) {
                lrCount = 0;
                learningRate = initialLearningRate;
            }
            // likewise once the rate has been decayed below its minimum
            if (learningRate < minimumLearningRate) {
                lrCount = 0;
                learningRate = initialLearningRate;
            }
            // on a severe accuracy drop, roll back to the best stored parameters
            if (bestAccuracy - accuracy > downgradeAccuracyThreshold) {
                params = ModelLoader.load(workingDir, bestAccuracy);
            }
        }
        // rebuild the network with the (possibly changed) learning rate and carry the parameters over
        net = new MultiLayerNetwork(model.apply(learningRate, width, height, channels, numLabels));
        net.init();
        net.setParameters(params);
        log.info("Learning rate {} for epoch {}", learningRate, i + 1);
    }
    log.info("Training completed");

}
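
The train method relies on a label(...) helper that is not shown in this section. A plausible implementation, assuming softmax network outputs and one-hot encoded labels, reduces each row to the index of its largest entry; this is a sketch under those assumptions, not necessarily the author's actual code:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Maps a softmax output row or a one-hot label row to the winning class
// index, boxed as a double so MulticlassMetrics can consume the pairs.
private static Object label(INDArray row) {
    return (double) Nd4j.argMax(row, 1).getInt(0);
}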

In this example, the decay of the learning rate is done manually, and every model with better accuracy than its predecessors is stored in the working folder. The following sketch traces how that schedule behaves over a run of epochs without improvement; all values are hypothetical and stand in for the trainer's configurable fields:
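
// hypothetical values, chosen only to make the schedule visible
double learningRate = 0.75, initialLearningRate = 0.75;
double learningRateDecayFactor = 0.5, minimumLearningRate = 0.1;
int stepDecayThreshold = 2, lrCount = 0;

// ten consecutive epochs without an accuracy improvement
for (int epoch = 0; epoch < 10; epoch++) {
    if (++lrCount % stepDecayThreshold == 0) {
        learningRate *= learningRateDecayFactor;   // 0.75 -> 0.375 -> 0.1875 -> ...
    }
    if (learningRate < minimumLearningRate) {      // decayed below the minimum:
        lrCount = 0;                               // start the schedule over
        learningRate = initialLearningRate;
    }
    System.out.printf("epoch %d: lr = %.4f%n", epoch, learningRate);
}

Resetting to the initial rate once the minimum is reached gives the search a chance to escape a plateau instead of freezing at a vanishingly small step size.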

The number of epochs, as well as every other parameter, can be provided for the learning process:

NetworkTrainer trainer = new NetworkTrainer.Builder()
    .model(ModelLibrary.net1)
    .networkToSparkNetwork(net -> new SparkDl4jMultiLayer(sc, net))
    .numLabels(labels.size())
    .cores(NUM_CORES).build();
...
trainer.train(trainDataset, testDataset);
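
The snippet above assumes an existing JavaSparkContext named sc and a NUM_CORES constant, neither of which is shown here. A minimal local setup might look like this (illustrative only):

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

int NUM_CORES = Runtime.getRuntime().availableProcessors();
SparkConf conf = new SparkConf()
        .setMaster("local[" + NUM_CORES + "]")   // or a cluster master URL
        .setAppName("cnn-training");
JavaSparkContext sc = new JavaSparkContext(conf);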

Training networks with the selected architectures leads to models with 75% - 77% accuracy after several epochs. Usually there is no further improvement after epoch 50, regardless of the learning rate being used.
