Project | kNN | Classifying IRIS Dataset
Introduction | kNN Algorithm
Statistical learning refers to a collection of mathematical and computational tools for understanding data. In what is often called supervised learning, the goal is to estimate or predict an output based on one or more inputs. The inputs go by many names; predictors, independent variables, features, and simply variables are among the most common. The output or outputs are often called response variables or dependent variables. If the response is quantitative (say, a number that measures weight or height), we call these problems regression problems. If the response is qualitative (say, yes or no, or blue or green), we call these problems classification problems.

This case study deals with one specific approach to classification. The goal is to set up a classifier such that, when it is presented with a new observation whose category is not known, it assigns that observation to a category, or class, based on the observations whose true category it does know. This specific method is known as the k-Nearest Neighbors classifier, or kNN for short. Given a positive integer k, say 5, and a new data point, it first identifies the k points in the data that are nearest to the new point, and then classifies the new point as belonging to the most common class among those k neighbors.
Aim: Build our very own k-Nearest Neighbors classifier to classify data from the IRIS dataset of scikit-learn.
Distance between two points
We are going to write a function that finds the distance between two given 2-D points in the x-y plane. We will import numpy and store the coordinates of each point in a numpy array. Being able to measure the distance between two points is what lets us find the nearest neighbors of an input point.
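A minimal sketch of such a distance function follows. It assumes the usual Euclidean (straight-line) distance, which the text does not name explicitly, and the function name distance() is our own choice. Because it works on numpy arrays, it also handles points with more than two coordinates.

```python
import numpy as np


def distance(p1, p2):
    """Euclidean distance between two points (assumed metric)."""
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    # Square the coordinate differences, sum them, and take the square root.
    return np.sqrt(np.sum((p2 - p1) ** 2))


p1 = np.array([0, 0])
p2 = np.array([3, 4])
print(distance(p1, p2))  # 5.0
```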
Majority vote counter
We will create a 3 x 3 grid of points with the help of a numpy array to build an environment of points dispersed in the plane. We will also create a function called majority_vote() that finds the value with the highest count (the winning vote) in a list of votes, e.g. (1, 2, 1, 1, 2, 3, 2, 2, 3, 1, 1, 2, 3, 3, 2, 3). This is essentially the mode of the data, so it can also be calculated with the scipy statistics module. We will create another function called majority_vote_short() which performs the same task as majority_vote() but uses mode() from scipy.stats. We will rely on these functions later when predicting the class of new points.
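Below is a minimal sketch of both vote counters. The tie-breaking rule in majority_vote() (pick one of the tied winners at random) is our assumption; scipy.stats.mode() instead returns the smallest of the tied values, so the two functions can disagree when there is a tie. For the example vote list above there is no tie: 2 receives six votes, while 1 and 3 receive five each.

```python
import random
from collections import Counter

import numpy as np
import scipy.stats as ss


def majority_vote(votes):
    """Return the most common vote; ties are broken at random (assumption)."""
    vote_counts = Counter(votes)
    max_count = max(vote_counts.values())
    winners = [vote for vote, count in vote_counts.items() if count == max_count]
    return random.choice(winners)


def majority_vote_short(votes):
    """Return the mode via scipy.stats.mode; ties go to the smallest value."""
    result = ss.mode(np.asarray(votes))
    # Older SciPy returns a length-1 array, newer SciPy a scalar: handle both.
    return np.ravel(result.mode)[0]


votes = [1, 2, 1, 1, 2, 3, 2, 2, 3, 1, 1, 2, 3, 3, 2, 3]
print(majority_vote(votes))        # 2
print(majority_vote_short(votes))  # 2
```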
Our aim is to build a kNN classifier, so we need an algorithm that finds the nearest neighbors of a given point within a set of points. Suppose we need to insert a point into the x-y plane among an existing set of points. We have to classify the point we wish to insert into one of the categories of the existing points, and then insert it accordingly. So, we will build a function find_nearest_neighbours() to find the nearest neighbors of the given point. It takes three parameters: (i) the point we wish to insert, (ii) the set of existing points, and (iii) k, the number of nearest neighbors whose indices the function should return. We will visualize the situation by plotting the x-y plane filled with points using matplotlib.
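One way to write find_nearest_neighbours() is sketched below, assuming Euclidean distance and assuming the function returns the indices of the k closest points, nearest first. The 3 x 3 grid and the query point p = (2.5, 2) are hypothetical example data. Returning indices rather than the points themselves lets the caller look up the class labels of those neighbors in a separate array.

```python
import numpy as np


def distance(p1, p2):
    """Euclidean distance between two points."""
    return np.sqrt(np.sum((np.asarray(p2, float) - np.asarray(p1, float)) ** 2))


def find_nearest_neighbours(p, points, k=5):
    """Return the indices of the k points in `points` closest to p."""
    distances = np.array([distance(p, q) for q in points])
    # argsort orders indices by increasing distance; keep the first k.
    return np.argsort(distances)[:k]


# Hypothetical environment: a 3 x 3 grid of existing points.
points = np.array([[1, 1], [1, 2], [1, 3],
                   [2, 1], [2, 2], [2, 3],
                   [3, 1], [3, 2], [3, 3]])
p = np.array([2.5, 2])  # the point we wish to insert

ind = find_nearest_neighbours(p, points, k=2)
print(points[ind])  # the two grid points at distance 0.5 from p
```

Both (2, 2) and (3, 2) lie exactly 0.5 away from p, so their relative order in the result depends on how argsort handles the tie.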
kNN Prediction on Synthetic Data
After finding the nearest neighbors, we will have to predict the category of the input point.We will build a function called knn_predict() which will predict the category of the point we wish to insert.We can build another function called generate_synth_data() to generate synthetic points in the x-y plane.
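The sketch below shows how knn_predict() can combine the neighbor search with a majority vote, and how generate_synth_data() might produce two labelled clouds of points. The data-generating choices are assumptions made for illustration: n points per class drawn from bivariate normal distributions with unit standard deviation, centered at (0, 0) for class 0 and (1, 1) for class 1. The seed parameter is a hypothetical addition for reproducibility.

```python
import random
from collections import Counter

import numpy as np


def majority_vote(votes):
    """Most common vote; ties broken at random (assumption)."""
    counts = Counter(votes)
    top = max(counts.values())
    return random.choice([v for v, c in counts.items() if c == top])


def find_nearest_neighbours(p, points, k=5):
    """Indices of the k points closest to p (Euclidean distance)."""
    distances = np.linalg.norm(points - p, axis=1)
    return np.argsort(distances)[:k]


def knn_predict(p, points, outcomes, k=5):
    """Predict the class of p as the majority class of its k nearest neighbors."""
    ind = find_nearest_neighbours(p, points, k)
    return majority_vote(outcomes[ind].tolist())


def generate_synth_data(n=50, seed=None):
    """Two Gaussian clouds of n points each, labelled 0 and 1 (assumed setup)."""
    rng = np.random.default_rng(seed)
    class0 = rng.normal(loc=0.0, scale=1.0, size=(n, 2))
    class1 = rng.normal(loc=1.0, scale=1.0, size=(n, 2))
    points = np.concatenate((class0, class1), axis=0)
    outcomes = np.concatenate((np.repeat(0, n), np.repeat(1, n)))
    return points, outcomes


points, outcomes = generate_synth_data(n=50, seed=1)
print(knn_predict(np.array([0.5, 0.5]), points, outcomes, k=5))
```

The point (0.5, 0.5) sits midway between the two class centers, so its predicted class depends on the random sample; points closer to either center will usually take that center's class.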
kNN Prediction Grid
We will build a function called make_prediction_grid() which lays a regular grid over a region of the plane and assigns each grid point the class predicted by our kNN classifier. Another function, plot_prediction_grid(), must be created to plot the output of make_prediction_grid() using matplotlib.
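A possible sketch of the two grid functions follows. The signatures are our assumptions: limits is a tuple (x_min, x_max, y_min, y_max), h is the grid step, and k is the number of neighbors. The colors (pink and light green cells, red and green points) follow the plot description, and the synthetic data in the __main__ block is hypothetical. The double loop is simple rather than fast, since it runs one kNN prediction per grid node.

```python
import random
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap


def knn_predict(p, points, outcomes, k=5):
    """Majority class among the k nearest points (ties broken at random)."""
    ind = np.argsort(np.linalg.norm(points - p, axis=1))[:k]
    counts = Counter(outcomes[ind].tolist())
    top = max(counts.values())
    return random.choice([v for v, c in counts.items() if c == top])


def make_prediction_grid(predictors, outcomes, limits, h, k):
    """Predict a class for every node of a grid covering `limits` with step h."""
    x_min, x_max, y_min, y_max = limits
    xs = np.arange(x_min, x_max, h)
    ys = np.arange(y_min, y_max, h)
    xx, yy = np.meshgrid(xs, ys)
    prediction_grid = np.zeros(xx.shape, dtype=int)
    for i, x in enumerate(xs):
        for j, y in enumerate(ys):
            # Row index follows y and column index follows x, matching meshgrid.
            prediction_grid[j, i] = knn_predict(np.array([x, y]), predictors, outcomes, k)
    return xx, yy, prediction_grid


def plot_prediction_grid(xx, yy, prediction_grid, predictors, outcomes):
    """Color the grid by predicted class and overlay the labelled points."""
    background = ListedColormap(["pink", "lightgreen"])
    observations = ListedColormap(["red", "green"])
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.pcolormesh(xx, yy, prediction_grid, cmap=background, vmin=0, vmax=1,
                  alpha=0.5, shading="auto")
    ax.scatter(predictors[:, 0], predictors[:, 1], c=outcomes,
               cmap=observations, vmin=0, vmax=1, s=30)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    return fig, ax


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    predictors = np.concatenate((rng.normal(0, 1, (50, 2)), rng.normal(1, 1, (50, 2))))
    outcomes = np.repeat([0, 1], 50)
    xx, yy, grid = make_prediction_grid(predictors, outcomes, (-3, 4, -3, 4), h=0.1, k=5)
    plot_prediction_grid(xx, yy, grid, predictors, outcomes)
    plt.show()
```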
Output: The plot shown here is a grid of two classes, shown visually as pink and green. We predicted the class of each point based on its position and the points around it. The green points should fall in the green cells of the grid and the red points in the pink cells. Look at the enlarged view to check visually that the classifier works.
Classifying the IRIS Dataset
We will test our classifier on a scikit-learn dataset called “IRIS”. To load “IRIS”, we import datasets from sklearn and call the function datasets.load_iris(). The “IRIS” dataset holds the sepal length, sepal width, petal length and petal width for three different classes of Iris flower: Iris-Setosa, Iris-Versicolour and Iris-Virginica. Based on this data, we need to classify the flowers and visualize the result using our classifier. The scikit-learn (sklearn) library already provides a prebuilt kNN classifier, so we will compare the two classifiers (scikit-learn's versus the one we built) and check the prediction accuracy of each.
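A sketch of this comparison follows. Several choices here are assumptions rather than details from the text. Only the first two features (sepal length and sepal width) are used, so the result can be drawn in the x-y plane. k = 5 is used for both classifiers. Accuracy is measured on the same 150 flowers both classifiers were given, which makes it training accuracy and likely more optimistic than accuracy on held-out data. Since both classifiers use Euclidean distance and k = 5, any disagreement between them likely comes from how ties among neighbors or votes are broken.

```python
import random
from collections import Counter

import numpy as np
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier


def knn_predict(p, points, outcomes, k=5):
    """Our kNN: majority class of the k nearest points (random tie-break)."""
    ind = np.argsort(np.linalg.norm(points - p, axis=1))[:k]
    counts = Counter(outcomes[ind].tolist())
    top = max(counts.values())
    return random.choice([v for v, c in counts.items() if c == top])


random.seed(0)  # make our random tie-breaking reproducible

iris = datasets.load_iris()
predictors = iris.data[:, 0:2]  # assumption: sepal length and sepal width only
outcomes = iris.target          # 0 = setosa, 1 = versicolor, 2 = virginica

# scikit-learn's prebuilt classifier
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(predictors, outcomes)
sk_predictions = knn.predict(predictors)

# Our classifier, applied to every flower in the dataset
my_predictions = np.array([knn_predict(p, predictors, outcomes, k=5) for p in predictors])

print("agreement: %.1f%%" % (100 * np.mean(sk_predictions == my_predictions)))
print("sklearn accuracy: %.1f%%" % (100 * np.mean(sk_predictions == outcomes)))
print("our accuracy: %.1f%%" % (100 * np.mean(my_predictions == outcomes)))
```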
Output: The output suggests that our classifier actually performs better than the sklearn classifier.
This article is contributed by Amaryta Ranjan Saikia.