import
numpy as np
import
random
import
scipy.stats as ss
import
matplotlib.pyplot as plt
def
make_prediction_grid(predictors, outcomes, limits, h, k):
(x_min, x_max, y_min, y_max)
=
limits
xs
=
np.arange(x_min, x_max, h)
ys
=
np.arange(y_min, y_max, h)
xx, yy
=
np.meshgrid(xs, ys)
prediction_grid
=
np.zeros(xx.shape, dtype
=
int
)
for
i, x
in
enumerate
(xs):
for
j, y
in
enumerate
(ys):
p
=
np.array([x, y])
prediction_grid[j, i]
=
knn_predict(p, predictors, outcomes, k)
return
(xx, yy, prediction_grid)
def
plot_prediction_grid (xx, yy, prediction_grid, filename):
""" Plot KNN predictions
for
every point on the grid."""
from
matplotlib.colors
import
ListedColormap
background_colormap
=
ListedColormap (["hotpink", "lightskyblue", "yellowgreen"])
observation_colormap
=
ListedColormap (["red", "blue", "green"])
plt.figure(figsize
=
(
10
,
10
))
plt.pcolormesh(xx, yy, prediction_grid, cmap
=
background_colormap, alpha
=
0.5
)
plt.scatter(predictors[:,
0
], predictors [:,
1
], c
=
outcomes, cmap
=
observation_colormap, s
=
50
)
plt.xlabel(
'Variable 1'
); plt.ylabel(
'Variable 2'
)
plt.xticks(()); plt.yticks(())
plt.xlim (np.
min
(xx), np.
max
(xx))
plt.ylim (np.
min
(yy), np.
max
(yy))
plt.savefig(filename)
(predictors, outcomes)
=
generate_synth_data()
k
=
5
; filename
=
"knn_synth_5.pdf"; limits
=
(
-
3
,
4
,
-
3
,
4
); h
=
0.1
(xx, yy, prediction_grid)
=
make_prediction_grid(predictors, outcomes, limits, h, k)
plot_prediction_grid(xx, yy, prediction_grid, filename)
plt.show()