import
java.util.Arrays;
import
java.util.List;
/**
 * Trains a single linear unit ({@code weight[0]*x1 + weight[1]*x2 + bias}, no activation
 * function) with the delta / least-mean-squares update rule on a small fixed dataset.
 * It prints each sample's error and the mean squared error after every epoch.
 *
 * <p>The default dataset is the bipolar OR truth table: inputs are -1 or 1, and the target
 * is -1 only when both inputs are -1.
 *
 * <p>All state lives in mutable static fields that training updates in place, so this class
 * is not thread-safe.
 */
public class Main {
    /** Training inputs; each inner list holds one sample's two features (x1, x2). */
    static List<List<Integer>> features =
            Arrays.asList(
                    Arrays.asList(-1, -1),
                    Arrays.asList(-1, 1),
                    Arrays.asList(1, -1),
                    Arrays.asList(1, 1));

    /** Target output for each sample, index-aligned with {@link #features}. */
    static List<Integer> labels = Arrays.asList(-1, 1, 1, 1);

    /** Weights for x1 and x2; updated in place during training. */
    static double[] weight = {0.5, 0.5};

    /** Bias term; updated in place during training. */
    static double bias = 0.1;

    /** Step size applied to every weight and bias update. */
    static double learning_rate = 0.2;

    /** Number of full passes over the dataset performed by {@link #main}. */
    static int epoch = 10;

    /**
     * Runs {@link #epoch} training epochs. For each epoch it prints the epoch number, each
     * sample's error, and the epoch's mean squared error.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        for (int i = 0; i < epoch; i++) {
            System.out.println("Epoch: " + (i + 1));
            double meanSquaredError = trainEpoch();
            // trainEpoch returns the squared error averaged over all samples, so label it a mean.
            System.out.println("Mean squared error = " + meanSquaredError + "\n");
        }
    }

    /**
     * Performs one pass over the dataset and applies the delta-rule update after every sample.
     * Each sample's error is printed as it is computed.
     *
     * @return the mean of the squared errors measured before each sample's update, or 0.0 if
     *     the dataset is empty
     * @throws IllegalStateException if {@code features} and {@code labels} differ in length
     */
    static double trainEpoch() {
        int sampleCount = features.size();
        if (labels.size() != sampleCount) {
            throw new IllegalStateException(
                    "features has " + sampleCount + " samples but labels has " + labels.size());
        }
        if (sampleCount == 0) {
            // Avoid 0.0 / 0 == NaN: an empty dataset contributes no error.
            return 0.0;
        }
        double sumSquaredError = 0.0;
        for (int j = 0; j < sampleCount; j++) {
            int actual = labels.get(j);
            int x1 = features.get(j).get(0);
            int x2 = features.get(j).get(1);
            double unit = x1 * weight[0] + x2 * weight[1] + bias;
            double error = actual - unit;
            System.out.println("Error = " + error);
            sumSquaredError += error * error;
            // Delta rule: move each parameter along error * (its input).
            weight[0] += learning_rate * error * x1;
            weight[1] += learning_rate * error * x2;
            bias += learning_rate * error;
        }
        // Divide by the real sample count rather than a hard-coded 4, so the mean stays
        // correct if the dataset changes.
        return sumSquaredError / sampleCount;
    }
}