Dataset and its augmentation

As the primary dataset for all experiments in this example, a small subset of the well-known CIFAR-10 dataset is used. It contains about 6k RGB images of 32x32 pixels, divided into 4 distinct categories: bird, car, cat, dog, and can be found on GitHub. Images are stored in a text file with the category name in the first column and the image data in the second column. The channels (R, G, B) are stored separately, so the first 32x32 bytes hold the red channel, the next 32x32 the green channel, and the last 32x32 the blue channel. The data set can be loaded and transformed into dl4j internal data formats with the following code:

JavaRDD<String> raw = sc.textFile("data/images-data-rgb.csv");
String first = raw.first();

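// Drop the first line of the file (assumed to be the header), then split
// each record into (category, pixel data)
JavaPairRDD<String, String> labelData = raw
    .filter(f -> !f.equals(first))
    .mapToPair(r -> {
        String[] tab = r.split(";");
        return new Tuple2<>(tab[0], tab[1]);
    });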

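JavaRDD<Tuple2<INDArray, double[]>> labelsWithData = labelData
    .map(t -> {
        INDArray label = FeatureUtil.toOutcomeVector(labels.get(t._1).intValue(), labels.size());
        // Parse the pixel values; scaling to [-1.0, 1.0] here is an assumption
        double[] arr = Arrays.stream(t._2.split(" "))
            .mapToDouble(Double::parseDouble).map(v -> v / 127.5 - 1.0).toArray();
        return new Tuple2<>(label, arr);
    });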
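
To make the record layout concrete, here is a minimal plain-Java sketch that parses a single line without Spark or dl4j. The class name `RecordParser`, the category order, and the `v / 127.5 - 1.0` scaling are assumptions made for illustration; the text only states that the values end up between -1.0 and 1.0.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of dl4j: parses one "label;p0 p1 p2 ..." line.
final class RecordParser {

    // Assumed category order; in the example above the index comes from the `labels` map.
    static final List<String> CATEGORIES = Arrays.asList("bird", "car", "cat", "dog");

    // Parses the second column and scales raw values from [0, 255] to [-1.0, 1.0]
    // (assumed formula).
    static double[] parsePixels(String line) {
        String pixelColumn = line.split(";")[1];
        return Arrays.stream(pixelColumn.split(" "))
                .mapToDouble(Double::parseDouble)
                .map(v -> v / 127.5 - 1.0)
                .toArray();
    }

    // Builds a one-hot vector for the category name in the first column.
    static double[] parseLabel(String line) {
        String category = line.split(";")[0];
        double[] oneHot = new double[CATEGORIES.size()];
        oneHot[CATEGORIES.indexOf(category)] = 1.0;
        return oneHot;
    }

    public static void main(String[] args) {
        String line = "cat;0 255 51";
        System.out.println(Arrays.toString(parseLabel(line)));
        System.out.println(Arrays.toString(parsePixels(line)));
    }
}
```

With this scaling, a raw value of 0 maps to -1.0 and 255 maps to 1.0, which matches the range mentioned in the text.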

One further step has to be taken, as dl4j expects its own "DataSet" data type. Here it is combined with splitting the data into training and test sets:

JavaRDD<Tuple2<INDArray, double[]>>[] splited = labelsWithData.randomSplit(new double[] { .8, .2 }, seed);

JavaRDD<DataSet> testDataset = splited[1]
    .map(t -> {
        INDArray features = Nd4j.create(t._2, new int[] { 1, t._2.length });
        return new DataSet(features, t._1);
    }).cache();

// log: an SLF4J-style logger (identifier assumed)
log.info("Number of test images {}", testDataset.count());

Because the data set isn't large, some augmentation can help achieve better accuracy. The simplest method is to add horizontally flipped copies of its images. For a trained network it should be irrelevant whether e.g. a cat is looking to the left or to the right. Here, all training images are flipped, but of course only a subset of the initial training data set could be used instead. Input data values are between 0 and 255, and after simple normalization between -1.0 and 1.0. The procedure is as follows:

JavaRDD<DataSet> plain = splited[0]
    .map(t -> {
        INDArray features = Nd4j.create(t._2, new int[] { 1, t._2.length });
        return new DataSet(features, t._1);
    });

JavaRDD<DataSet> flipped = splited[0]
    .map(t -> {
        double[] arr = t._2;
        int idx = 0;
        double[] farr = new double[arr.length];
        // Walk the array row by row (across all channels) and write each row reversed
        for (int i = 0; i < arr.length; i += trainer.width) {
            double[] temp = Arrays.copyOfRange(arr, i, i + trainer.width);
            for (int j = trainer.width - 1; j >= 0; --j) {
                farr[idx++] = temp[j];
            }
        }
        INDArray features = Nd4j.create(farr, new int[] { 1, farr.length });
        return new DataSet(features, t._1);
    });

JavaRDD<DataSet> trainDataset = plain.union(flipped).cache();

// log: an SLF4J-style logger (identifier assumed)
log.info("Number of train images {}", trainDataset.count());
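
The flipping step can be checked in isolation with a small plain-Java sketch. The class `HorizontalFlip` and its `flip` method are hypothetical names used only for illustration. It relies on the layout described earlier: the channels are stored one after another, so reversing every run of `width` values mirrors each row of each channel.

```java
import java.util.Arrays;

// Hypothetical helper: mirrors every row of a channel-planar image array.
final class HorizontalFlip {

    // `data` is treated as consecutive rows of `width` values each; its length
    // is assumed to be a multiple of `width`.
    static double[] flip(double[] data, int width) {
        double[] flipped = new double[data.length];
        for (int rowStart = 0; rowStart < data.length; rowStart += width) {
            for (int col = 0; col < width; ++col) {
                // The leftmost output value comes from the rightmost input value.
                flipped[rowStart + col] = data[rowStart + width - 1 - col];
            }
        }
        return flipped;
    }

    public static void main(String[] args) {
        // Tiny single-channel "image", 3 pixels wide: rows {1,2,3} and {4,5,6}.
        double[] image = {1, 2, 3, 4, 5, 6};
        System.out.println(Arrays.toString(flip(image, 3)));
    }
}
```

A quick sanity check for any flip implementation is that applying it twice returns the original array.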

Out of curiosity, or for testing purposes, an example image from the data set can be turned into a BufferedImage with this code:

    static final int IMAGE_DEPTH = 3;
    static final int IMAGE_WIDTH = 32;
    static final int IMAGE_HEIGHT = 32;
    static final int IMAGE_SIZE = IMAGE_WIDTH * IMAGE_HEIGHT;

    public static BufferedImage getImageFromArray(int[] pixels) {
        BufferedImage image = new BufferedImage(IMAGE_WIDTH, IMAGE_HEIGHT, BufferedImage.TYPE_INT_RGB);
        // Pixel i takes its red, green and blue values from the three consecutive channel planes
        for (int i = 0; i < pixels.length / IMAGE_DEPTH; ++i) {
            int rgb = new Color(pixels[i], pixels[i + IMAGE_SIZE], pixels[i + IMAGE_SIZE + IMAGE_SIZE]).getRGB();
            // x is the column within the row, y is the row index
            image.setRGB(i % IMAGE_WIDTH, i / IMAGE_WIDTH, rgb);
        }
        return image;
    }

and then easily saved, e.g. in PNG format:

  int[] pixels = sampleImage(new File("data/images-data-rgb.csv"));
  BufferedImage bi = getImageFromArray(pixels);
  ImageIO.write(bi, "PNG", new File("test.png"));

using the following function to select a sample image from the provided dataset:

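    public static int[] sampleImage(File file) {
        try (BufferedReader br = new BufferedReader(new FileReader(file))) {
            // Skip the first line (assumed to be the header), then parse the pixel column of one record
            return br.lines().skip(1).findAny()
                    .map(l -> Arrays.stream(l.split(";")[1].split(" ")).mapToInt(Integer::parseInt).toArray())
                    .orElse(null);
        } catch (IOException e) {
            return null;
        }
    }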
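
When testing the conversion above, the opposite direction can be handy as well. The following sketch turns a BufferedImage back into the same channel-planar layout (all red values, then all green, then all blue). The class `PlanarPixels` and the method `toPlanarArray` are hypothetical names introduced for illustration; only standard `java.awt.image` calls are used.

```java
import java.awt.image.BufferedImage;
import java.util.Arrays;

// Hypothetical helper: converts a BufferedImage into the planar layout of the data file.
final class PlanarPixels {

    // Returns all red values first, then all green, then all blue, each plane row by row.
    static int[] toPlanarArray(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        int planeSize = width * height;
        int[] pixels = new int[3 * planeSize];
        for (int y = 0; y < height; ++y) {
            for (int x = 0; x < width; ++x) {
                int rgb = image.getRGB(x, y);   // packed as 0xAARRGGBB
                int i = y * width + x;          // position within one plane
                pixels[i] = (rgb >> 16) & 0xFF;             // red plane
                pixels[i + planeSize] = (rgb >> 8) & 0xFF;  // green plane
                pixels[i + 2 * planeSize] = rgb & 0xFF;     // blue plane
            }
        }
        return pixels;
    }

    public static void main(String[] args) {
        // 2x1 image: one pure red pixel followed by one pure blue pixel.
        BufferedImage image = new BufferedImage(2, 1, BufferedImage.TYPE_INT_RGB);
        image.setRGB(0, 0, 0xFF0000);
        image.setRGB(1, 0, 0x0000FF);
        System.out.println(Arrays.toString(toPlanarArray(image)));
    }
}
```

Feeding the result of such a conversion into getImageFromArray should reproduce the original picture for 32x32 inputs, which gives a simple round-trip test of the index arithmetic.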
