Java 8 example of regression algorithms used for House prices predictions

Instead of a Scala case class, a regular Java bean can be used:

public class HouseModelJava implements LabeledPointConverter {

    private final Long id;
    private final Date date;
    private final Double price;
    private final Integer bedrooms;
    private final Double bathrooms;
    private final Integer sqft_living;
    private final Integer sqft_lot;
    private final Double floors;
    private final Integer waterfront;
    private final Integer view;
    private final Integer condition;
    private final Integer grade;
    private final Integer sqft_above;
    private final Integer sqft_basement;
    private final Integer yr_built;
    private final Integer yr_renovated;
    private final String zipcode;
    private final Double latitude;
    private final Double longitude;
    private final Integer sqft_living15;
    private final Integer sqft_lot15;

    public HouseModelJava(Long id, Date date, Double price, Integer bedrooms, Double bathrooms, Integer sqft_living,
            Integer sqft_lot, Double floors, Integer waterfront, Integer view, Integer condition, Integer grade,
            Integer sqft_above, Integer sqft_basement, Integer yr_built, Integer yr_renovated, String zipcode,
            Double latitude, Double longitude, Integer sqft_living15, Integer sqft_lot15) {
        super(); = id; = date;
        this.price = price;
        this.bedrooms = bedrooms;
        this.bathrooms = bathrooms;
        this.sqft_living = sqft_living;
        this.sqft_lot = sqft_lot;
        this.floors = floors;
        this.waterfront = waterfront;
        this.view = view;
        this.condition = condition;
        this.grade = grade;
        this.sqft_above = sqft_above;
        this.sqft_basement = sqft_basement;
        this.yr_built = yr_built;
        this.yr_renovated = yr_renovated;
        this.zipcode = zipcode;
        this.latitude = latitude;
        this.longitude = longitude;
        this.sqft_living15 = sqft_living15;
        this.sqft_lot15 = sqft_lot15;

    public HouseModelJava(String... row) {
        this(Long.parseLong(row[0]), new Date(parseDate(row[1])), Double.parseDouble(row[2]), Integer.parseInt(row[3]),
                Double.parseDouble(row[4]), Integer.parseInt(row[5]), Integer.parseInt(row[6]), Double
                        .parseDouble(row[7]), Integer.parseInt(row[8]), Integer.parseInt(row[9]), Integer
                        .parseInt(row[10]), Integer.parseInt(row[11]), Integer.parseInt(row[12]), Integer
                        .parseInt(row[13]), Integer.parseInt(row[14]), Integer.parseInt(row[15]), row[16], Double
                        .parseDouble(row[17]), Double.parseDouble(row[18]), Integer.parseInt(row[19]), Integer

    public LabeledPoint toLabeledPoint() {
        return new LabeledPoint(label(), features());

    public double label() {
        return price;

    public Vector features() {
        double[] features = { id, bedrooms, bathrooms, sqft_living, sqft_lot, floors, waterfront, view, condition,
                grade, sqft_above, sqft_basement, yr_built, yr_renovated, latitude, longitude, sqft_living15,
                sqft_lot15 };
        return Vectors.dense(features);

    static Long parseDate(String value) {
        try {
            return new java.text.SimpleDateFormat("yyyyMMdd'T'hhmmss").parse(value).getTime();
        } catch (ParseException e) {
            throw new RuntimeException(e);
    /* getters */

A function that creates a linear regression model in Java 8 can look like this:

    static LinearRegressionModel createLinearRegressionModel(JavaRDD<LabeledPoint> rdd, Integer numIterations,
            Double stepSize) {
        return LinearRegressionWithSGD.train(rdd.rdd(),
                numIterations == null ? 100 : numIterations,
                stepSize == null ? 0.01 : stepSize);

A function that creates a decision tree regression model in Java 8 can look like this:

    static DecisionTreeModel createDecisionTreeRegressionModel(JavaRDD<LabeledPoint> rdd, Integer maxDepth,
            Integer maxBins) {
        String impurity = "variance";
        return DecisionTree.trainRegressor(rdd,
                impurity, maxDepth == null ? 10 : maxDepth,
                maxBins == null ? 20 : maxBins);

And the rest of the code:

        // End-to-end run: load the CSV, build labeled points, scale the features,
        // train a decision tree on 90% of the data and report metrics on the remaining 10%.
        // configLocalMode, regressionApp and localFile are helpers defined elsewhere.
        try (JavaSparkContext sc = new JavaSparkContext(configLocalMode(regressionApp()))) {
            JavaRDD<String> hdFile = localFile("house-data.csv", sc);

            // Split each line on commas, drop the header row (its second column is
            // the literal "date"), and convert every remaining row to a LabeledPoint.
            JavaRDD<LabeledPoint> houses = hdFile.map(s -> s.split(",")).
                    filter(t -> !"date".equals(t[1]) ).
                    map(a -> new HouseModelJava(a).toLabeledPoint());

            // Fit the scaler on the feature vectors only.
            StandardScalerModel scaler = new StandardScaler(true, true).fit(houses.map(dp -> dp.features()).rdd());

            // Rescale every point, then split 90/10 with a fixed seed for repeatable runs.
            JavaRDD<LabeledPoint>[] split = houses.map(
                    dp -> new LabeledPoint(dp.label(), scaler.transform(dp.features())))
                    .randomSplit(new double[] { .9, .1 }, 10204L);
            JavaRDD<LabeledPoint> train = split[0].cache();
            JavaRDD<LabeledPoint> test = split[1].cache();

            DecisionTreeModel model = createDecisionTreeRegressionModel(train, null, null);

            // Print a handful of predictions next to their true labels.
            test.take(5).forEach(
                            x -> System.out.println(String.format("Predicted: %.1f, Label: %.1f",
                                    model.predict(x.features()), x.label()))
            );

            // Pair each prediction with its true label for the metrics below.
            JavaPairRDD<Object, Object> predictionsAndValues = test.mapToPair(
                    p -> new Tuple2<Object, Object>(model.predict(p.features()), p.label())
            );

            System.out.println("Mean house price: " + test.mapToDouble(x -> x.label()).mean());
            System.out.println("Max prediction error: "
                    + predictionsAndValues.mapToDouble(
                            t2 -> Math.abs(Double.class.cast(t2._2) - Double.class.cast(t2._1))).max());

            RegressionMetrics metrics = new RegressionMetrics(predictionsAndValues.rdd());

            System.out.println(String.format("Mean Squared Error: %.2f", metrics.meanSquaredError()));
            System.out.println(String.format("Root Mean Squared Error: %.2f", metrics.rootMeanSquaredError()));
            System.out.println(String.format("Coefficient of Determination R-squared: %.2f", metrics.r2()));
            System.out.println(String.format("Mean Absoloute Error: %.2f", metrics.meanAbsoluteError()));
            System.out.println(String.format("Explained variance: %.2f", metrics.explainedVariance()));
        }

Sample output can look like this:

Predicted: 481028,4, Label: 510000,0

Predicted: 732639,1, Label: 937000,0

Predicted: 437099,1, Label: 438000,0

Predicted: 632961,1, Label: 580500,0

Predicted: 328065,5, Label: 322500,0

Mean house price: 534198.5644402637

Max prediction error: 1710000.0

Mean Squared Error: 22243597500,87

Root Mean Squared Error: 149142,88

Coefficient of Determination R-squared: 0,81

Mean Absoloute Error: 87380,94

Explained variance: 110086143895,00

results matching ""

    No results matching ""