Java 8 example of regression algorithms used for house price prediction
Instead of a Scala case class, a regular Java class can be used:
/**
 * House-sale record built from one split CSV row and convertible to an MLlib {@link LabeledPoint}.
 *
 * <p>The label is {@code price}; the features are the numeric columns listed in {@link #features()}.
 * All fields are final.
 */
public class HouseModelJava implements LabeledPointConverter {

    /** Minimum number of columns a row must have for {@link #HouseModelJava(String...)}. */
    private static final int COLUMN_COUNT = 21;

    /**
     * Pattern of the date column. {@code HH} is the 0-23 hour of day ({@code hh} would be the
     * 1-12 am/pm clock hour, which is the wrong field for a 24-hour timestamp).
     */
    private static final String DATE_PATTERN = "yyyyMMdd'T'HHmmss";

    private final Long id;
    private final Date date;
    private final Double price;
    private final Integer bedrooms;
    private final Double bathrooms;
    private final Integer sqft_living;
    private final Integer sqft_lot;
    private final Double floors;
    private final Integer waterfront;
    private final Integer view;
    private final Integer condition;
    private final Integer grade;
    private final Integer sqft_above;
    private final Integer sqft_basement;
    private final Integer yr_built;
    private final Integer yr_renovated;
    private final String zipcode;
    private final Double latitude;
    private final Double longitude;
    private final Integer sqft_living15;
    private final Integer sqft_lot15;

    /** Creates a record from already-typed column values. */
    public HouseModelJava(Long id, Date date, Double price, Integer bedrooms, Double bathrooms, Integer sqft_living,
            Integer sqft_lot, Double floors, Integer waterfront, Integer view, Integer condition, Integer grade,
            Integer sqft_above, Integer sqft_basement, Integer yr_built, Integer yr_renovated, String zipcode,
            Double latitude, Double longitude, Integer sqft_living15, Integer sqft_lot15) {
        this.id = id;
        // Copy the mutable Date so later changes by the caller cannot alter this record.
        this.date = date == null ? null : new Date(date.getTime());
        this.price = price;
        this.bedrooms = bedrooms;
        this.bathrooms = bathrooms;
        this.sqft_living = sqft_living;
        this.sqft_lot = sqft_lot;
        this.floors = floors;
        this.waterfront = waterfront;
        this.view = view;
        this.condition = condition;
        this.grade = grade;
        this.sqft_above = sqft_above;
        this.sqft_basement = sqft_basement;
        this.yr_built = yr_built;
        this.yr_renovated = yr_renovated;
        this.zipcode = zipcode;
        this.latitude = latitude;
        this.longitude = longitude;
        this.sqft_living15 = sqft_living15;
        this.sqft_lot15 = sqft_lot15;
    }

    /**
     * Parses one split CSV row whose columns follow the order of the typed constructor.
     * Column 1 is the date and column 16 (zip code) is kept as text.
     *
     * @param row the row's columns; trailing columns beyond the 21st are ignored
     * @throws IllegalArgumentException if {@code row} is null or has fewer than 21 columns, if the
     *         date column does not match the expected pattern, or (as {@link NumberFormatException})
     *         if a numeric column cannot be parsed
     */
    public HouseModelJava(String... row) {
        // requireColumns is evaluated first (Java evaluates arguments left to right), so a short
        // row fails with a clear message instead of an ArrayIndexOutOfBoundsException.
        this(Long.parseLong(requireColumns(row)[0]), new Date(parseDate(row[1])), Double.parseDouble(row[2]),
                Integer.parseInt(row[3]), Double.parseDouble(row[4]), Integer.parseInt(row[5]),
                Integer.parseInt(row[6]), Double.parseDouble(row[7]), Integer.parseInt(row[8]),
                Integer.parseInt(row[9]), Integer.parseInt(row[10]), Integer.parseInt(row[11]),
                Integer.parseInt(row[12]), Integer.parseInt(row[13]), Integer.parseInt(row[14]),
                Integer.parseInt(row[15]), row[16], Double.parseDouble(row[17]), Double.parseDouble(row[18]),
                Integer.parseInt(row[19]), Integer.parseInt(row[20]));
    }

    /** Returns {@code row} unchanged after checking that it has at least {@link #COLUMN_COUNT} columns. */
    private static String[] requireColumns(String[] row) {
        if (row == null || row.length < COLUMN_COUNT) {
            throw new IllegalArgumentException("Expected at least " + COLUMN_COUNT + " columns but got "
                    + (row == null ? "null" : String.valueOf(row.length)));
        }
        return row;
    }

    /** Combines {@link #label()} and {@link #features()} into a single {@link LabeledPoint}. */
    @Override
    public LabeledPoint toLabeledPoint() {
        return new LabeledPoint(label(), features());
    }

    /** Returns {@code price}, the regression target. */
    @Override
    public double label() {
        return price;
    }

    /**
     * Returns the numeric features in this fixed order: id, bedrooms, bathrooms, sqft_living,
     * sqft_lot, floors, waterfront, view, condition, grade, sqft_above, sqft_basement, yr_built,
     * yr_renovated, latitude, longitude, sqft_living15, sqft_lot15. The date and zip code are not included.
     */
    @Override
    public Vector features() {
        // NOTE(review): id identifies the record rather than describing the house; using it as a
        // feature lets models fit noise. It is kept only so the vector layout stays unchanged.
        double[] features = { id, bedrooms, bathrooms, sqft_living, sqft_lot, floors, waterfront, view, condition,
                grade, sqft_above, sqft_basement, yr_built, yr_renovated, latitude, longitude, sqft_living15,
                sqft_lot15 };
        return Vectors.dense(features);
    }

    /**
     * Parses a date column into epoch milliseconds, interpreting it in the JVM default time zone.
     *
     * @throws IllegalArgumentException if {@code value} does not match {@link #DATE_PATTERN}
     */
    static Long parseDate(String value) {
        // A new formatter per call: SimpleDateFormat is not thread-safe, and this runs inside
        // Spark map tasks via the row constructor.
        try {
            return new java.text.SimpleDateFormat(DATE_PATTERN).parse(value).getTime();
        } catch (ParseException e) {
            throw new IllegalArgumentException(
                    "Unparseable date '" + value + "', expected pattern " + DATE_PATTERN, e);
        }
    }

    /* getters */
}
A function that creates a linear regression model in Java 8 can look like this:
/**
 * Trains a linear regression model with stochastic gradient descent.
 *
 * <p>NOTE(review): LinearRegressionWithSGD is deprecated in newer Spark releases; consider the
 * spark.ml LinearRegression estimator if this example is upgraded.
 *
 * @param rdd           training points
 * @param numIterations number of SGD iterations, or {@code null} to use 100
 * @param stepSize      SGD step size, or {@code null} to use 0.01
 * @return the trained model
 */
static LinearRegressionModel createLinearRegressionModel(JavaRDD<LabeledPoint> rdd, Integer numIterations,
        Double stepSize) {
    final int iterations = (numIterations != null) ? numIterations : 100;
    final double step = (stepSize != null) ? stepSize : 0.01;
    return LinearRegressionWithSGD.train(rdd.rdd(), iterations, step);
}
A function that creates a decision tree model for regression in Java 8 can look like this:
/**
 * Trains a regression decision tree using the "variance" impurity measure.
 *
 * <p>An empty categorical-features map is passed, so every feature is treated as continuous.
 *
 * @param rdd      training points
 * @param maxDepth maximum tree depth, or {@code null} to use 10
 * @param maxBins  maximum number of bins used when splitting features, or {@code null} to use 20
 * @return the trained tree
 */
static DecisionTreeModel createDecisionTreeRegressionModel(JavaRDD<LabeledPoint> rdd, Integer maxDepth,
        Integer maxBins) {
    final int depth = (maxDepth != null) ? maxDepth : 10;
    final int bins = (maxBins != null) ? maxBins : 20;
    return DecisionTree.trainRegressor(rdd, Collections.<Integer, Integer>emptyMap(), "variance", depth, bins);
}
And here is the rest of the code:
try (JavaSparkContext sc = new JavaSparkContext(configLocalMode(regressionApp()))) {
JavaRDD<String> hdFile = localFile("house-data.csv", sc);
JavaRDD<LabeledPoint> houses = hdFile.map(s -> s.split(",")).
filter(t -> !"date".equals(t[1]) ).
map(a -> new HouseModelJava(a).toLabeledPoint());
StandardScalerModel scaler = new StandardScaler(true, true).fit(houses.map(dp -> dp.features()).rdd());
JavaRDD<LabeledPoint>[] split = houses.map(
dp -> new LabeledPoint(dp.label(), scaler.transform(dp.features())))
.randomSplit(new double[] { .9, .1 }, 10204L);
JavaRDD<LabeledPoint> train = split[0].cache();
JavaRDD<LabeledPoint> test = split[1].cache();
DecisionTreeModel model = createDecisionTreeRegressionModel(train, null, null);
test.take(5)
.stream()
.forEach(
x -> System.out.println(String.format("Predicted: %.1f, Label: %.1f",
model.predict(x.features()), x.label()))
);
JavaPairRDD<Object, Object> predictionsAndValues = test.mapToPair(
p -> new Tuple2<Object, Object>(model.predict(p.features()), p.label())
);
System.out.println("Mean house price: " + test.mapToDouble(x -> x.label()).mean());
System.out.println("Max prediction error: "
+ predictionsAndValues.mapToDouble(
t2 -> Math.abs(Double.class.cast(t2._2) - Double.class.cast(t2._1))).max());
RegressionMetrics metrics = new RegressionMetrics(predictionsAndValues.rdd());
System.out.println(String.format("Mean Squared Error: %.2f", metrics.meanSquaredError()));
System.out.println(String.format("Root Mean Squared Error: %.2f", metrics.rootMeanSquaredError()));
System.out.println(String.format("Coefficient of Determination R-squared: %.2f", metrics.r2()));
System.out.println(String.format("Mean Absoloute Error: %.2f", metrics.meanAbsoluteError()));
System.out.println(String.format("Explained variance: %.2f", metrics.explainedVariance()));
}
}
Sample output can look like this:
Predicted: 481028,4, Label: 510000,0
Predicted: 732639,1, Label: 937000,0
Predicted: 437099,1, Label: 438000,0
Predicted: 632961,1, Label: 580500,0
Predicted: 328065,5, Label: 322500,0
Mean house price: 534198.5644402637
Max prediction error: 1710000.0
Mean Squared Error: 22243597500,87
Root Mean Squared Error: 149142,88
Coefficient of Determination R-squared: 0,81
Mean Absoloute Error: 87380,94
Explained variance: 110086143895,00