predict_proba for classification problem in Python
In this tutorial, we’ll look at the predict_proba() function for classification problems in Python. The main difference between predict_proba() and predict() is that predict_proba() gives the probability of each target class, whereas predict() gives the single class predicted for a given set of features.
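To see the contrast before diving into the tutorial, here is a minimal sketch on a toy two-class dataset (the data and variable names below are illustrative, not part of the tutorial):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy dataset: one feature, two classes
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

clf = LogisticRegression().fit(X, y)

print(clf.predict([[1.5]]))        # a single class label per sample
print(clf.predict_proba([[1.5]]))  # one probability per class, each row sums to 1
```

predict() answers "which class?", while predict_proba() answers "how likely is each class?" for the same input.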
Importing our classifier
The classifier we’ll use for this is LogisticRegression from sklearn.linear_model. We then create our LogisticRegression model m.
from sklearn.linear_model import LogisticRegression
m = LogisticRegression()
Getting our dataset
The dataset we’re using for this tutorial is the famous Iris dataset, which is available in the sklearn.datasets module.
from sklearn.datasets import load_iris
iris = load_iris()
Now, let’s take a look at the dataset’s features and targets.
iris.feature_names
iris.target_names
Output:
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
Splitting our data
The next step is to split our data into a training set and a testing set. For this, we import train_test_split() from the sklearn.model_selection module.
from sklearn.model_selection import train_test_split
X = iris.data
y = iris.target
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.1)
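Note that train_test_split() shuffles the data randomly, so each run produces a different split (and therefore slightly different probabilities later on). If you want a reproducible split, you can pass a random_state; this is a sketch, and the seed value 42 here is arbitrary:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# Fixing random_state makes the split identical across runs
Xtrain, Xtest, ytrain, ytest = train_test_split(
    iris.data, iris.target, test_size=0.1, random_state=42
)
print(Xtrain.shape, Xtest.shape)  # (135, 4) (15, 4)
```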
Now, we’ll take a look at the shapes of our resulting training and testing data.
print(Xtrain.shape)
print(Xtest.shape)
Output:
(135, 4)
(15, 4)
Training our model
Since we have split our dataset, it is now time for us to train our model using the fit() method and print its accuracy.
m.fit(Xtrain, ytrain)
print(m.score(Xtest, ytest))
Output:
1.0
As you can see, we got an accuracy score of 1.0, which is perfect! Yay!
Using predict_proba
Now, let’s see what happens when we call predict_proba. For more information on the predict_proba method, visit its documentation.
m.predict_proba(Xtest)
Output:
array([[8.29639556e-01, 1.70346663e-01, 1.37808397e-05],
       [8.48022771e-01, 1.51903019e-01, 7.42102237e-05],
       [2.15082716e-03, 4.19671627e-01, 5.78177546e-01],
       [1.08867316e-02, 7.12889122e-01, 2.76224146e-01],
       [2.06046308e-04, 2.66292366e-01, 7.33501588e-01],
       [8.77741863e-01, 1.22250469e-01, 7.66768013e-06],
       [4.46856465e-03, 3.53529407e-01, 6.42002028e-01],
       [8.03924450e-01, 1.96012309e-01, 6.32412272e-05],
       [9.09784658e-01, 9.02012752e-02, 1.40667886e-05],
       [2.96751485e-04, 2.92144656e-01, 7.07558593e-01],
       [9.74437252e-04, 3.46964072e-01, 6.52061491e-01],
       [3.56926619e-03, 3.60715696e-01, 6.35715037e-01],
       [8.76114455e-01, 1.23877298e-01, 8.24653734e-06],
       [8.75120615e-01, 1.24837439e-01, 4.19457555e-05],
       [7.58789806e-01, 2.41162916e-01, 4.72776226e-05]])
This output gives the probability of each target class for every tuple (row) of the testing set. The three columns correspond to the classes setosa, versicolor, and virginica, and each row sums to 1.
To make things clearer, let’s predict the targets of the testing set using the normal predict() method.
ypred = m.predict(Xtest)
ypred
Output:
array([0, 0, 2, 1, 2, 0, 2, 0, 0, 2, 2, 2, 0, 0, 0])
You can verify this by comparing the outputs of the two methods: for each row, predict() returns the class whose column in predict_proba() holds the highest probability. You can also spot prediction errors by comparing ypred with the actual ytest values.
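The relationship between the two methods can be checked in code: predict() picks the column with the largest probability in each row of predict_proba(). A self-contained sketch (the random_state seed and max_iter value here are illustrative choices, not part of the tutorial):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

iris = load_iris()
Xtrain, Xtest, ytrain, ytest = train_test_split(
    iris.data, iris.target, test_size=0.1, random_state=0
)

# max_iter raised so the lbfgs solver converges cleanly on the full Iris data
m = LogisticRegression(max_iter=200).fit(Xtrain, ytrain)

proba = m.predict_proba(Xtest)

# Each row of probabilities sums to 1
assert np.allclose(proba.sum(axis=1), 1.0)

# predict() agrees with the argmax of predict_proba(), mapped through classes_
assert np.array_equal(m.predict(Xtest), m.classes_[proba.argmax(axis=1)])
```

If both assertions pass, the two methods are consistent: predict_proba() carries the full probability distribution, and predict() is just its most likely class.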