KNN Classification using Scikit-Learn in Python
Today we’ll learn KNN Classification using Scikit-learn in Python.
KNN stands for K Nearest Neighbors. The KNN algorithm can be used for both classification and regression problems. It assumes that similar data points lie in close proximity to each other.
Thus, when an unknown input is encountered, the classes of the known inputs in its proximity are checked, and the class with the highest count is assigned to the unknown input.
The algorithm first calculates the distance between the unknown point and every known point. It then takes the k closest points, where the value of k is chosen by us. The classes of these k points then determine the class of our unknown point.
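For intuition, here is a minimal sketch of that decision rule in plain NumPy. This is only an illustration of the idea, not scikit-learn's implementation; the function name knn_predict and the toy points are made up for the example.

import numpy as np
from collections import Counter

def knn_predict(X_known, y_known, x_unknown, k=3):
    # distance from the unknown point to every known point
    distances = np.linalg.norm(X_known - x_unknown, axis=1)
    # indices of the k closest known points
    nearest = np.argsort(distances)[:k]
    # the most common class among those k neighbours wins
    return Counter(y_known[nearest]).most_common(1)[0][0]

# toy data: two classes, and an unknown point near the second cluster
X = np.array([[1.0, 1.0], [1.2, 0.8], [5.0, 5.0], [5.2, 4.9]])
y = np.array([0, 0, 1, 1])
print(knn_predict(X, y, np.array([4.8, 5.1])))   # prints 1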
So let’s start coding!
Importing Libraries:
The first import from sklearn is the dataset we are going to work with. I chose the wine dataset because it is great for a beginner. You can also look at the other datasets provided by sklearn or import your own dataset.
The next import is train_test_split, which splits the dataset into a training set and a testing set.
Following this, we'll import the KNN classifier itself, KNeighborsClassifier.
Lastly, we import the accuracy_score to check the accuracy of our KNN model.
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
Loading the dataset:
Now that our libraries are imported, we load our dataset. A dataset is loaded by calling “load_<dataset_name>()”, which returns a bunch object. In this case, our bunch object is “wine”.
wine=load_wine()
We can now check a sample of the data and the shape of the data present in the wine bunch object using wine.data and wine.data.shape respectively.
print(wine.data)
print(wine.data.shape)
Output:
[[1.423e+01 1.710e+00 2.430e+00 ... 1.040e+00 3.920e+00 1.065e+03]
 [1.320e+01 1.780e+00 2.140e+00 ... 1.050e+00 3.400e+00 1.050e+03]
 [1.316e+01 2.360e+00 2.670e+00 ... 1.030e+00 3.170e+00 1.185e+03]
 ...
 [1.327e+01 4.280e+00 2.260e+00 ... 5.900e-01 1.560e+00 8.350e+02]
 [1.317e+01 2.590e+00 2.370e+00 ... 6.000e-01 1.620e+00 8.400e+02]
 [1.413e+01 4.100e+00 2.740e+00 ... 6.100e-01 1.600e+00 5.600e+02]]
(178, 13)
Now we know that our data consists of 178 entries and 13 columns. The columns are called features, and they decide which class the corresponding input belongs to. The class here is called the target. So, we can now check the targets, target names and feature names.
print(wine.target)
print(wine.target_names)
print(wine.feature_names)
Output:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
['class_0' 'class_1' 'class_2']
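From here, the imports above already suggest the remaining steps: split the data, fit a KNN classifier and check its accuracy. A short sketch of that workflow is below; the test_size, random_state and n_neighbors values are just example choices, not fixed requirements.

# split the wine data into a training set and a testing set
X_train, X_test, y_train, y_test = train_test_split(wine.data, wine.target, test_size=0.3, random_state=0)

# fit a KNN classifier that votes among the 5 nearest neighbours
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)

# predict the classes of the test set and measure the accuracy
y_pred = knn.predict(X_test)
print(accuracy_score(y_test, y_pred))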