# Multiclass classification using scikit-learn

Multiclass classification is a popular problem in supervised machine learning.

**Problem –** We are given a dataset of **m** training examples. Each example contains information in the form of **n** features, together with a label. Each label corresponds to the class to which the training example belongs. In multiclass classification, there is a finite set of more than two classes.

For example, when identifying different types of fruit, “Shape”, “Color” and “Radius” can be features, and “Apple”, “Orange” and “Banana” can be class labels.
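The code in this article uses the iris dataset that ships with scikit-learn, where **m** and **n** can be read directly off the shape of the feature matrix. A minimal sketch, assuming only `datasets.load_iris()`, which the article's own snippets also use:

```python
from sklearn import datasets

# Load the iris dataset: a feature matrix X and a label vector y
iris = datasets.load_iris()
X, y = iris.data, iris.target

# m = number of training examples, n = number of features per example
m, n = X.shape
print("m =", m, "examples, n =", n, "features")

# The feature names and the finite set of class labels
print("features:", list(iris.feature_names))
print("classes:", list(iris.target_names))
```

Here the features are four flower measurements (in cm) and the three classes are the iris species.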

In multiclass classification, we train a classifier on the training data and then use it to classify new examples.
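In scikit-learn, this train-then-classify pattern is expressed through every classifier's `fit` and `predict` methods. A minimal sketch, using `KNeighborsClassifier` purely as an example; the "new" flower measurement below is a hypothetical input chosen for illustration:

```python
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Train a classifier on the labelled training data
clf = KNeighborsClassifier(n_neighbors=3).fit(X, y)

# A hypothetical new example: sepal length, sepal width,
# petal length, petal width (all in cm)
new_example = [[5.0, 3.4, 1.5, 0.2]]

# Classify the new example and map the numeric label back to a class name
predicted = clf.predict(new_example)
print(iris.target_names[predicted[0]])
```

Swapping in any other scikit-learn classifier leaves the `fit`/`predict` calls unchanged.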

**Aim of this article –** We will use different multiclass classification methods, such as KNN, decision trees, SVM and Naive Bayes, and compare their accuracy on test data. We will do all of this with scikit-learn (Python). For information on how to install and use scikit-learn, visit http://scikit-learn.org/stable/

**Approach –**

- Load dataset from source.
- Split the dataset into “training” and “test” data.
- Train decision tree, SVM, KNN and Naive Bayes classifiers on the training data.
- Use the above classifiers to predict labels for the test data.
- Measure accuracy and compute a confusion matrix for each classifier.
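The approach above can be sketched end to end in a single script. This is a minimal sketch assuming the same iris split (`random_state=0`) and the same hyperparameters as the per-classifier snippets in this article; it also uses the fact that correct predictions lie on the diagonal of the confusion matrix, so accuracy equals the matrix's trace divided by its total:

```python
from sklearn import datasets
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Load the dataset and split it into training and test data
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

# Same hyperparameters as the individual snippets in this article
classifiers = {
    "Decision tree": DecisionTreeClassifier(max_depth=2),
    "Linear SVM": SVC(kernel="linear", C=1),
    "KNN (k=7)": KNeighborsClassifier(n_neighbors=7),
    "Gaussian Naive Bayes": GaussianNB(),
}

results = {}
confusion_matrices = {}
for name, clf in classifiers.items():
    # Train on the training data, predict labels for the test data
    clf.fit(X_train, y_train)
    predictions = clf.predict(X_test)

    # Correct predictions sit on the diagonal of the confusion matrix
    cm = confusion_matrix(y_test, predictions)
    results[name] = cm.trace() / cm.sum()
    confusion_matrices[name] = cm

    print(f"{name}: accuracy = {results[name]:.3f}")
    print(cm)
```

Keeping the split fixed across classifiers is what makes their accuracies directly comparable.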

**Decision tree classifier –** A decision tree classifier is a systematic approach to multiclass classification. It poses a set of questions about the dataset's attributes (features). The decision tree classification algorithm can be visualized as a binary tree. At the root and at each internal node, a question is posed, and the data at that node is split into subsets whose records have different characteristics; in scikit-learn, each such question compares one feature against a threshold. The leaves of the tree correspond to the classes into which the dataset is split. In the following code snippet, we train a decision tree classifier in scikit-learn.

```python
# importing necessary libraries
from sklearn import datasets
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# loading the iris dataset
iris = datasets.load_iris()

# X -> features, y -> label
X = iris.data
y = iris.target

# dividing X, y into train and test data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# training a DecisionTreeClassifier
dtree_model = DecisionTreeClassifier(max_depth=2).fit(X_train, y_train)
dtree_predictions = dtree_model.predict(X_test)

# accuracy on X_test
accuracy = dtree_model.score(X_test, y_test)
print(accuracy)

# creating a confusion matrix
cm = confusion_matrix(y_test, dtree_predictions)
print(cm)
```
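To see the questions the trained tree actually asks at each node, scikit-learn's `export_text` can print the tree as indented rules. A minimal sketch, assuming the same split and `max_depth=2` as the snippet above; the exact thresholds printed depend on the fitted tree:

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

dtree_model = DecisionTreeClassifier(max_depth=2).fit(X_train, y_train)

# Each internal node is printed as a "feature <= threshold" question;
# each leaf is printed with the class it predicts
rules = export_text(dtree_model, feature_names=list(iris.feature_names))
print(rules)

# max_depth=2 bounds the tree to at most 2 levels of questions (at most 4 leaves)
print("depth:", dtree_model.get_depth(), "leaves:", dtree_model.get_n_leaves())
```

Limiting `max_depth` keeps the tree small enough to read, at the possible cost of some accuracy.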

**SVM (Support vector machine) classifier –**

SVM (support vector machine) is an effective classification method, particularly when the feature vectors are high-dimensional. In scikit-learn, we can specify the kernel function (here, linear). To learn more about kernel functions and SVMs, see the scikit-learn documentation on SVM kernels and the SVM article listed in the references.

```python
# importing necessary libraries
from sklearn import datasets
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# loading the iris dataset
iris = datasets.load_iris()

# X -> features, y -> label
X = iris.data
y = iris.target

# dividing X, y into train and test data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# training a linear SVM classifier
svm_model_linear = SVC(kernel='linear', C=1).fit(X_train, y_train)
svm_predictions = svm_model_linear.predict(X_test)

# model accuracy for X_test
accuracy = svm_model_linear.score(X_test, y_test)
print(accuracy)

# creating a confusion matrix
cm = confusion_matrix(y_test, svm_predictions)
print(cm)
```
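Because the kernel is just a parameter of `SVC`, trying other kernel functions only means changing the `kernel` argument. A minimal sketch, assuming the same split and `C=1` as above and leaving every other `SVC` parameter at its default; the resulting scores are not claimed to generalize beyond this split:

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

# Same C as the snippet above; only the kernel function changes
kernel_scores = {}
for kernel in ["linear", "poly", "rbf"]:
    model = SVC(kernel=kernel, C=1).fit(X_train, y_train)
    kernel_scores[kernel] = model.score(X_test, y_test)
    print(f"{kernel:>6} kernel: accuracy = {kernel_scores[kernel]:.3f}")
```

`SVC` handles the three iris classes internally by training one-vs-one binary classifiers, so no extra code is needed for the multiclass case.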

**KNN (k-nearest neighbours) classifier –** KNN, or k-nearest neighbours, is one of the simplest classification algorithms. It is non-parametric: it makes no assumptions about the underlying structure or distribution of the data. Whenever a new example is encountered, its k nearest neighbours in the training data are examined. The distance between two examples can be, for instance, the Euclidean distance between their feature vectors. The majority class among the k nearest neighbours is taken as the class of the new example.

```python
# importing necessary libraries
from sklearn import datasets
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# loading the iris dataset
iris = datasets.load_iris()

# X -> features, y -> label
X = iris.data
y = iris.target

# dividing X, y into train and test data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# training a KNN classifier
knn = KNeighborsClassifier(n_neighbors=7).fit(X_train, y_train)

# accuracy on X_test
accuracy = knn.score(X_test, y_test)
print(accuracy)

# creating a confusion matrix
knn_predictions = knn.predict(X_test)
cm = confusion_matrix(y_test, knn_predictions)
print(cm)
```
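The KNN procedure described above (Euclidean distance, k nearest neighbours, majority vote) is short enough to write out directly. A minimal sketch with NumPy; `knn_predict` is a hypothetical helper for illustration, not part of scikit-learn, and its predictions are compared against `KNeighborsClassifier` rather than assumed identical, since tie-breaking between equally distant neighbours may differ:

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

def knn_predict(X_train, y_train, X_query, k=7):
    """Hypothetical helper: majority vote among the k nearest training examples."""
    predictions = []
    for x in X_query:
        # Euclidean distance from x to every training example
        distances = np.sqrt(((X_train - x) ** 2).sum(axis=1))
        # Indices of the k closest training examples
        nearest = np.argsort(distances, kind="stable")[:k]
        # Majority vote; argmax returns the smallest label when votes tie
        votes = np.bincount(y_train[nearest])
        predictions.append(votes.argmax())
    return np.array(predictions)

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

manual = knn_predict(X_train, y_train, X_test, k=7)
library = KNeighborsClassifier(n_neighbors=7).fit(X_train, y_train).predict(X_test)

print("manual accuracy:", (manual == y_test).mean())
print("agreement with KNeighborsClassifier:", (manual == library).mean())
```

Note that there is no real training step: all the work happens at prediction time, which is why KNN can be slow on large training sets.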

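**Naive Bayes classifier –** The Naive Bayes classification method is based on Bayes’ theorem. It is termed “naive” because it assumes that every pair of features is conditionally independent given the class label. Let $(x_1, x_2, \dots, x_n)$ be a feature vector and let $y$ be the class label corresponding to this feature vector.

Applying Bayes’ theorem,

$$P(y \mid x_1, \dots, x_n) = \frac{P(y)\, P(x_1, \dots, x_n \mid y)}{P(x_1, \dots, x_n)}$$

Since $x_1, x_2, \dots, x_n$ are assumed to be conditionally independent given $y$, $P(x_1, \dots, x_n \mid y) = \prod_{i=1}^{n} P(x_i \mid y)$, so

$$P(y \mid x_1, \dots, x_n) = \frac{P(y) \prod_{i=1}^{n} P(x_i \mid y)}{P(x_1, \dots, x_n)}$$

The denominator $P(x_1, \dots, x_n)$ does not depend on $y$, so it is constant for a given input and can be dropped, leaving a proportionality:

$$P(y \mid x_1, \dots, x_n) \propto P(y) \prod_{i=1}^{n} P(x_i \mid y)$$

Therefore, the class label is decided by

$$\hat{y} = \arg\max_{y} \; P(y) \prod_{i=1}^{n} P(x_i \mid y)$$

$P(y)$ is the relative frequency of class label $y$ in the training dataset.

In the case of the Gaussian Naive Bayes classifier, $P(x_i \mid y)$ is calculated as

$$P(x_i \mid y) = \frac{1}{\sqrt{2\pi\sigma_y^2}} \exp\left(-\frac{(x_i - \mu_y)^2}{2\sigma_y^2}\right)$$

where $\mu_y$ and $\sigma_y^2$ are the mean and variance of feature $x_i$ over the training examples of class $y$.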

```python
# importing necessary libraries
from sklearn import datasets
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# loading the iris dataset
iris = datasets.load_iris()

# X -> features, y -> label
X = iris.data
y = iris.target

# dividing X, y into train and test data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# training a Naive Bayes classifier
gnb = GaussianNB().fit(X_train, y_train)
gnb_predictions = gnb.predict(X_test)

# accuracy on X_test
accuracy = gnb.score(X_test, y_test)
print(accuracy)

# creating a confusion matrix
cm = confusion_matrix(y_test, gnb_predictions)
print(cm)
```
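The Gaussian Naive Bayes decision rule can also be computed by hand to see what `GaussianNB` does. A minimal sketch with NumPy, assuming priors equal to class frequencies and per-class, per-feature means and variances; it works with log-probabilities, which leaves the argmax unchanged (log is monotonic) but avoids multiplying many small numbers. The tiny variance floor mirrors `GaussianNB`'s default `var_smoothing=1e-9`; `gaussian_nb_predict` is a hypothetical helper:

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

classes = np.unique(y_train)
# P(y): relative frequency of each class in the training data
priors = np.array([(y_train == c).mean() for c in classes])
# Per-class mean and variance of every feature, shape (n_classes, n_features)
means = np.array([X_train[y_train == c].mean(axis=0) for c in classes])
variances = np.array([X_train[y_train == c].var(axis=0) for c in classes])
# Small variance floor, mirroring GaussianNB's default var_smoothing=1e-9
variances += 1e-9 * X_train.var(axis=0).max()

def gaussian_nb_predict(X):
    """Hypothetical helper: argmax over y of log P(y) + sum_i log P(x_i | y)."""
    # Gaussian log-density of each feature under each class: shape (n_samples, n_classes)
    log_likelihood = -0.5 * (
        np.log(2 * np.pi * variances)[None, :, :]
        + (X[:, None, :] - means[None, :, :]) ** 2 / variances[None, :, :]
    ).sum(axis=2)
    log_posterior = np.log(priors)[None, :] + log_likelihood
    return classes[log_posterior.argmax(axis=1)]

manual = gaussian_nb_predict(X_test)
library = GaussianNB().fit(X_train, y_train).predict(X_test)

print("manual accuracy:", (manual == y_test).mean())
print("agreement with GaussianNB:", (manual == library).mean())
```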

**References –**

- http://scikit-learn.org/stable/modules/naive_bayes.html
- https://en.wikipedia.org/wiki/Multiclass_classification
- http://scikit-learn.org/stable/documentation.html
- http://scikit-learn.org/stable/modules/tree.html
- http://scikit-learn.org/stable/modules/svm.html#svm-kernels
- https://www.analyticsvidhya.com/blog/2015/10/understaing-support-vector-machine-example-code/

This article is contributed by **Arik Pamnani**.