-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathTestOnMovieLensDataset.java
129 lines (114 loc) · 5.36 KB
/
TestOnMovieLensDataset.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package com.indeed.vw.wrapper.integration.tests;
import com.indeed.vw.wrapper.api.example.ExampleBuilder;
import com.indeed.vw.wrapper.api.parameters.SGDVowpalWabbitBuilder;
import com.indeed.vw.wrapper.api.VowpalWabbit;
import com.indeed.vw.wrapper.api.parameters.Loss;
import com.indeed.vw.wrapper.integration.IntegrationSuite;
import com.indeed.vw.wrapper.progvalidation.Metrics;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Integration test that trains and evaluates a rating-prediction model on the
 * MovieLens data set. It doubles as an example of how vowpal-wabbit can be used
 * for recommendation systems.
 *
 * <p>Each input row is turned into an example with four namespaces:
 * {@code user_id}, {@code movie_id}, {@code demographics_features} and
 * {@code film_features}. The label is the rating found in the row.
 */
public class TestOnMovieLensDataset extends IntegrationSuite {

    // Positions of the fields inside the column list handed to parseWvExample.
    private static final int RATING_POS = 0;
    // Not referenced in this class; kept to document the row layout.
    private static final int TIMESTAMP_POS = 1;
    private static final int USER_ID_POS = 2;
    private static final int MOVIE_ID_POS = 3;
    private static final int U_GENDER_POS = 4;
    private static final int U_AGE_POS = 5;
    private static final int U_OCCUPATION_POS = 6;
    // Not referenced in this class; kept to document the row layout.
    private static final int U_ZIP_CODE_POS = 7;
    private static final int M_TITLE_POS = 8;
    private static final int M_GENRES_POS = 9;

    /**
     * Matches a four-digit number starting with 1 or 2 followed by 9 or 0
     * (e.g. 19xx, 20xx) that has a non-digit character on both sides. Group 1
     * holds the number itself.
     *
     * <p>Compiled once and shared: {@link Pattern} instances are immutable and
     * safe to reuse across calls.
     *
     * <p>NOTE(review): {@link Matcher#find()} returns the first match, so a title
     * that itself contains such a number before the release year would yield
     * that number instead. Confirm whether the data set contains such titles.
     */
    private static final Pattern YEAR_PATTERN = Pattern.compile("[^\\d]([12][90]\\d\\d)[^\\d]");

    /**
     * Builds the vowpal-wabbit configuration used for this data set.
     *
     * @return the configured builder
     */
    @Override
    protected SGDVowpalWabbitBuilder configureVowpalWabbit() {
        return VowpalWabbit.builder()
                // A higher bit precision consumes more RAM but lowers the chance
                // of hash collisions, which improves quality.
                // This parameter can be increased even further.
                .bitPrecision(22)
                // Always try experimenting with adaptive, invariant and normalized.
                .adaptive().invariant()
                .lossFunction(Loss.squared)
                .learningRate(0.15)
                // Most of the magic happens here.
                // LRQFA - low rank quadratic feature-aware interactions.
                // This option allows learning a latent interaction
                // between user_id and movie_id.
                // Notice that it learns the interaction between user_id and movie_id
                // even if the concrete pair (user_id, movie_id) didn't occur in the train set.
                // For more details search for "feature-aware factorization machines".
                .lrqfa("user_id", "movie_id", 7)
                // LRQFA learns interactions only for those users and movies that are
                // present in the train set. Film features represent movie_id but are
                // much less sparse, so it makes sense to generate quadratic features
                // with them.
                .quadratic("user_id", "film_features")
                .quadratic("demographics_features", "movie_id")
                // Useful constraints: predictions are limited to the range 1..5.
                .minPrediction(1)
                .maxPrediction(5)
                // The regularization term should be small in a high-dimensional feature
                // space, otherwise the optimizer is pushed away from the loss minimum.
                .l2(0.000001);
    }

    /**
     * Converts one input row into a vowpal-wabbit example.
     *
     * @param columns the fields of a single row, indexed by the {@code *_POS} constants
     * @return an example builder labelled with the rating and populated with the
     *         user, movie, demographics and film namespaces
     * @throws NumberFormatException if the rating field is not a parsable number
     */
    @Override
    protected ExampleBuilder parseWvExample(final List<String> columns) {
        final double rating = Double.parseDouble(columns.get(RATING_POS));
        final ExampleBuilder exampleBuilder = ExampleBuilder.create()
                .label(rating);

        exampleBuilder.createNamespace("user_id")
                .addCategoricalFeature(columns.get(USER_ID_POS));
        exampleBuilder.createNamespace("movie_id")
                .addCategoricalFeature(columns.get(MOVIE_ID_POS));

        // This number is tuned through progressive validation.
        final double secondaryFeaturesWeight = 0.1;

        exampleBuilder.createNamespace("demographics_features")
                .namespaceWeight(secondaryFeaturesWeight)
                .addCategoricalFeature("age", columns.get(U_AGE_POS))
                .addCategoricalFeature("gender", columns.get(U_GENDER_POS))
                .addCategoricalFeature("occupation", columns.get(U_OCCUPATION_POS));

        final String title = columns.get(M_TITLE_POS);
        final ExampleBuilder.NamespaceBuilder filmFeaturesNamespace =
                exampleBuilder.createNamespace("film_features")
                        .namespaceWeight(secondaryFeaturesWeight)
                        .addTextAsFeatures(columns.get(M_GENRES_POS))
                        .addTextAsFeatures(title);

        // Add the decade of the year found in the title, when there is one.
        final Matcher yearMatcher = YEAR_PATTERN.matcher(title);
        if (yearMatcher.find()) {
            final int year = Integer.parseInt(yearMatcher.group(1));
            // Round down to the decade, e.g. 1995 -> "1990".
            final String decade = String.valueOf((year / 10) * 10);
            filmFeaturesNamespace.addCategoricalFeature("creation_decade", decade);
        }
        return exampleBuilder;
    }

    /**
     * @return the name of the metric this test checks: {@code "RMSE"}
     */
    @Override
    protected String getMetricToVerify() {
        return "RMSE";
    }

    /**
     * Creates the regression metrics used for progressive validation.
     *
     * @param printEveryN passed through to {@link Metrics#regressionMetrics(int)}
     * @return regression metrics
     */
    @Override
    protected Metrics createProgressiveValidation(final int printEveryN) {
        return Metrics.regressionMetrics(printEveryN);
    }

    /**
     * @return the field separator of the input CSV files: a comma
     */
    @Override
    protected char getInputCsvSeparator() {
        return ',';
    }

    /**
     * @return the reference test score for this data set; how it is compared
     *         against the measured metric is decided by the parent suite
     */
    @Override
    protected double expectedTestScore() {
        return 0.908;
    }

    /**
     * @return the path of the gzipped training CSV
     */
    @Override
    protected String getTrainPath() {
        return "/movie-lens-dataset/train.csv.gz";
    }

    /**
     * @return the path of the gzipped test CSV
     */
    @Override
    protected String getTestPath() {
        return "/movie-lens-dataset/test.csv.gz";
    }

    /**
     * @return the path of the model file
     */
    @Override
    protected String getModelPath() {
        return "/movie-lens-dataset/model.8.2.0.bin";
    }
}