Drawing ROC Curves and PR Curves

Evaluating the performance of classification models is an important topic in machine learning research. In this article, we focus on two commonly used evaluation tools, the ROC curve and the PR curve. We introduce their principles, show how to compute and plot them, and then use them to analyze and compare the performance of a classifier.

1. Select the data set

1. Generate the data set

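Here, scikit-learn's make_classification() function generates a synthetic binary classification data set with 1000 samples and 10 features. train_test_split() then divides it into a training set and a test set: test_size=0.2 means that 20% of the samples (200) are used as the test set, and random_state=10 makes both steps reproducible.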

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Generate a classification problem data set
X, y = make_classification(n_samples=1000, n_features=10, random_state=10)

# Divide the data set into training set and test set
train_features, test_features, train_labels, test_labels = train_test_split(X, y, test_size=0.2, random_state=10)

print("Training set features:", train_features)
print("Training set labels:", train_labels)

[Figure: printed training set features and labels]
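Printing the full feature matrix is hard to read, and the later comparison of ROC and PR curves depends on how balanced the classes are. The sketch below, which repeats the generation settings above so it runs on its own, checks the split sizes and the class counts instead. By default (weights=None), make_classification produces balanced classes, and flip_y=0.01 then randomly reassigns a small fraction of labels, so the counts are only roughly equal.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Same settings as the article: 1000 samples, 10 features, random_state=10
X, y = make_classification(n_samples=1000, n_features=10, random_state=10)
train_features, test_features, train_labels, test_labels = train_test_split(
    X, y, test_size=0.2, random_state=10)

# Shapes confirm the 80/20 split
print("Train:", train_features.shape, "Test:", test_features.shape)

# Number of samples of class 0 and class 1 in each split
print("Train class counts:", np.bincount(train_labels))
print("Test class counts:", np.bincount(test_labels))

# Fraction of positive (class 1) samples in the test set
print("Positive fraction in test set:", test_labels.mean())
```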

2. Classify

Classification uses the k-nearest neighbors (KNN) classifier implemented in the scikit-learn library.

from sklearn.neighbors import KNeighborsClassifier

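# Create a KNN classifier that votes among the 5 nearest neighbors
clf = KNeighborsClassifier(n_neighbors=5)

# Train the model on the training set
clf.fit(train_features, train_labels)

# Score the test set: column 1 of predict_proba is the estimated probability
# of the positive class, which serves as the score for the ROC and PR curves
y_hat = clf.predict_proba(test_features)[:, 1]

print("Predicted probability of class 1:", y_hat)
print("True value:", test_labels)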

[Figure: printed predicted probabilities and true labels]
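Note that y_hat holds probabilities, not class labels. With the default uniform weights, KNN's probability for class 1 is simply the fraction of the k neighbors that belong to class 1, so with k = 5 every score is a multiple of 0.2. The sketch below (which regenerates the same data so it is self-contained) checks this and also checks that predict() agrees with thresholding the score at 0.5. The practical consequence is that the ROC and PR curves built from these scores have only a handful of distinct thresholds and look coarse.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=1000, n_features=10, random_state=10)
train_features, test_features, train_labels, test_labels = train_test_split(
    X, y, test_size=0.2, random_state=10)

clf = KNeighborsClassifier(n_neighbors=5).fit(train_features, train_labels)
proba = clf.predict_proba(test_features)[:, 1]

# Each score is (number of class-1 neighbors) / 5, so at most six values occur
print("Distinct scores:", np.unique(proba))

# predict() is the majority vote, i.e. the score thresholded at 0.5
# (with an odd k and two classes, a 0.5 tie cannot happen)
hard = clf.predict(test_features)
print("predict() matches proba > 0.5:", np.array_equal(hard, (proba > 0.5).astype(int)))
```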

2. Draw ROC curve and PR curve

1. Confusion matrix

A confusion matrix is a table that summarizes the performance of a classification model by counting how its predictions compare with the true labels. Evaluation indicators such as accuracy, recall, and precision can all be calculated from it to further assess the model.


TP, FP, TN, and FN are the four basic counts used to measure the performance of a binary classification model:
TP (True Positive): the number of samples that are actually positive and are correctly predicted as positive.
FP (False Positive): the number of samples that are actually negative but are incorrectly predicted as positive.
TN (True Negative): the number of samples that are actually negative and are correctly predicted as negative.
FN (False Negative): the number of samples that are actually positive but are incorrectly predicted as negative.
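A small sketch of how these four counts can be obtained in practice, assuming hypothetical toy labels and scores (not the article's data) and an assumed decision threshold of 0.5. It uses scikit-learn's confusion_matrix and cross-checks each count against its definition.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical toy labels and scores (not the article's data set)
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
scores = np.array([0.9, 0.2, 0.4, 0.8, 0.6, 0.1, 0.7, 0.3])

# Turn scores into hard predictions with an assumed threshold of 0.5
y_pred = (scores >= 0.5).astype(int)

# With labels=[0, 1], scikit-learn lays the matrix out as [[TN, FP], [FN, TP]]
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

# The same four counts computed directly from their definitions
tp_manual = np.sum((y_true == 1) & (y_pred == 1))
fp_manual = np.sum((y_true == 0) & (y_pred == 1))
tn_manual = np.sum((y_true == 0) & (y_pred == 0))
fn_manual = np.sum((y_true == 1) & (y_pred == 0))

print(cm)
print(f"TP={tp}, FP={fp}, TN={tn}, FN={fn}")
```

For these toy values the counts come out as TP = 3, FP = 1, TN = 3, FN = 1. Changing the threshold changes all four counts, which is exactly what the ROC and PR curves track.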

2. ROC curve

The ROC (Receiver Operating Characteristic) curve plots the true positive rate (TPR) against the false positive rate (FPR). In a binary classification problem, the model assigns each sample a score and predicts the positive class when the score is at or above a chosen threshold. Each threshold yields one (FPR, TPR) pair, and sweeping the threshold from high to low traces the curve from (0, 0) to (1, 1). TPR, also called recall, is the proportion of actual positive samples that are correctly predicted as positive, that is, TPR = TP / (TP + FN). FPR is the proportion of actual negative samples that are incorrectly predicted as positive, that is, FPR = FP / (FP + TN).
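The threshold sweep described above can be sketched directly with NumPy. This is an illustrative reimplementation, not the article's code: roc_points is a hypothetical helper, and the labels and scores are toy values. It treats a sample as positive when its score is at or above the threshold, starting from a threshold above every score.

```python
import numpy as np

def roc_points(y_true, scores):
    """Hypothetical helper: sweep the threshold over every distinct score, return FPR, TPR."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores)
    n_pos = np.sum(y_true == 1)  # TP + FN
    n_neg = np.sum(y_true == 0)  # FP + TN
    # Start above the highest score, so that nothing is predicted positive
    thresholds = np.concatenate(([np.inf], np.unique(scores)[::-1]))
    fpr, tpr = [], []
    for t in thresholds:
        pred_pos = scores >= t
        tp = np.sum(pred_pos & (y_true == 1))
        fp = np.sum(pred_pos & (y_true == 0))
        tpr.append(tp / n_pos)  # TPR = TP / (TP + FN)
        fpr.append(fp / n_neg)  # FPR = FP / (FP + TN)
    return np.array(fpr), np.array(tpr), thresholds

# Hypothetical toy labels and scores
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
scores = [0.9, 0.2, 0.4, 0.8, 0.6, 0.1, 0.7, 0.3]
fpr, tpr, thresholds = roc_points(y_true, scores)
for t, f, r in zip(thresholds, fpr, tpr):
    print(f"threshold >= {t}: FPR = {f:.2f}, TPR = {r:.2f}")
```

With these toy scores, TPR climbs to 0.75 before any negative is admitted; the threshold 0.6 is the first to let a negative through, so FPR jumps to 0.25 while TPR stays at 0.75. Passing drop_intermediate=False to sklearn.metrics.roc_curve should yield the same points.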

3. PR curve

The PR (precision-recall) curve is another tool for evaluating a classification model. Like the ROC curve, it describes the model's performance across different thresholds. In a binary classification problem, precision is the proportion of samples predicted as positive that are actually positive, that is, Precision = TP / (TP + FP). Recall, also called the true positive rate, is the proportion of actual positive samples that are correctly predicted as positive, that is, Recall = TP / (TP + FN).
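The same kind of threshold sweep produces the PR curve. The sketch below uses hypothetical toy labels and scores and computes precision and recall at each distinct score. It then forms the average precision as the precision at each threshold weighted by the gain in recall, AP = Σ (R_n − R_(n−1)) · P_n, which is the definition scikit-learn documents for average_precision_score.

```python
import numpy as np

# Hypothetical toy labels and scores
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
scores = np.array([0.9, 0.2, 0.4, 0.8, 0.6, 0.1, 0.7, 0.3])
n_pos = np.sum(y_true == 1)  # TP + FN

precisions, recalls = [], []
for t in np.unique(scores)[::-1]:  # thresholds from high to low
    pred_pos = scores >= t
    tp = np.sum(pred_pos & (y_true == 1))
    fp = np.sum(pred_pos & (y_true == 0))
    precisions.append(tp / (tp + fp))  # Precision = TP / (TP + FP)
    recalls.append(tp / n_pos)         # Recall = TP / (TP + FN)
    print(f"threshold >= {t}: precision = {precisions[-1]:.3f}, recall = {recalls[-1]:.2f}")

# Average precision: precision at each threshold weighted by the increase in recall
ap = sum((r - r_prev) * p
         for p, r, r_prev in zip(precisions, recalls, [0.0] + recalls[:-1]))
print("Average precision:", ap)
```

Unlike TPR and FPR, precision is not monotonic in the threshold. Here it dips to 0.75 at threshold 0.6 and rises back to 0.80 at 0.4, which is one reason PR curves are usually drawn as step plots. For these toy scores the weighted sum gives AP = 0.95.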

4. The difference between the ROC curve and the PR curve

Different goals: the ROC curve shows the trade-off between the true positive rate (TPR) and the false positive rate (FPR), while the PR curve shows the trade-off between precision and recall.
Different axes: the ROC curve has TPR (recall) on the vertical axis and FPR on the horizontal axis, while the PR curve has precision on the vertical axis and recall on the horizontal axis.
Different application scenarios: because TPR and FPR are each computed within a single class, the ROC curve is insensitive to the ratio of positive to negative samples. This makes it a good choice when the classes are reasonably balanced, but it can look overly optimistic when negatives vastly outnumber positives. Precision compares the false positives directly with the true positives, so the PR curve is more informative when the positive class is rare, i.e., when the ratio of positive to negative samples is very uneven.
Different summary metrics: the AUC (Area Under the Curve) of the ROC curve summarizes the classifier's performance over all thresholds; it equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one (counting ties as one half). The average precision (AP) summarizes the PR curve as the mean of the precisions at each threshold, weighted by the increase in recall from the previous threshold.
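One way to see why class imbalance affects the two curves differently is to look at a classifier with no skill at all. The sketch below assumes a hypothetical data set of 10,000 samples, of which 500 (5%) are positive, and assigns purely random scores. The ROC AUC of such a classifier stays near 0.5 whatever the class ratio, whereas its average precision falls to roughly the fraction of positives. So on imbalanced data the PR curve makes it much clearer how far a model is above chance.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

rng = np.random.default_rng(0)

# Hypothetical imbalanced problem: 500 positives out of 10,000 samples (5%)
y_true = np.zeros(10_000, dtype=int)
y_true[:500] = 1

# A no-skill classifier: scores are random noise unrelated to the labels
scores = rng.random(10_000)

roc = roc_auc_score(y_true, scores)
ap = average_precision_score(y_true, scores)
print("ROC AUC:", roc)           # near 0.5, the diagonal baseline
print("Average precision:", ap)  # near 0.05, the positive fraction
```

The same reasoning gives the baselines to draw on each plot. For ROC it is the diagonal from (0, 0) to (1, 1). For PR it is a horizontal line at precision equal to the fraction of positive samples.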

5. Calculate the data points of the ROC curve and PR curve

The data points for both curves, along with their summary metrics, are computed with functions from sklearn.metrics:

# Calculate data points of ROC curve and PR curve
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

# ROC curve: FPR and TPR at each threshold, plus the area under the curve
fpr, tpr, roc_thresholds = roc_curve(test_labels, y_hat)
roc_auc = auc(fpr, tpr)

# PR curve: precision and recall at each threshold, plus average precision
precision, recall, pr_thresholds = precision_recall_curve(test_labels, y_hat)
pr_avg = average_precision_score(test_labels, y_hat)

The roc_curve function calculates the data points of the ROC curve: it takes the true labels test_labels and the predicted scores y_hat as input and returns the FPR, the TPR, and the corresponding thresholds. The auc function then computes the area under the ROC curve (AUC) using the trapezoidal rule. AUC is one of the most commonly used indicators for evaluating a binary classification model. Its value ranges from 0 to 1. The closer it is to 1, the better the model separates positive samples from negative ones. A value of 0.5 corresponds to random guessing, and values below 0.5 are worse than random.
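The area computed by auc can also be read as a ranking probability. The sketch below, on hypothetical toy labels and scores, computes the AUC three ways: with roc_curve plus auc (as in the article), with roc_auc_score, and directly as the fraction of (positive, negative) pairs in which the positive sample gets the higher score, with ties counted as one half.

```python
import numpy as np
from sklearn.metrics import roc_curve, auc, roc_auc_score

# Hypothetical toy labels and scores
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
scores = np.array([0.9, 0.2, 0.4, 0.8, 0.6, 0.1, 0.7, 0.3])

# Area under the ROC curve via the trapezoidal rule
fpr, tpr, _ = roc_curve(y_true, scores)
area = auc(fpr, tpr)

# Ranking view: compare every positive score with every negative score
pos = scores[y_true == 1]
neg = scores[y_true == 0]
diff = pos[:, None] - neg[None, :]
rank_prob = np.mean(diff > 0) + 0.5 * np.mean(diff == 0)

print("auc(fpr, tpr):", area)
print("roc_auc_score:", roc_auc_score(y_true, scores))
print("pairwise ranking probability:", rank_prob)
```

All three agree here (0.9375). Only one of the 16 pairs is misordered: the positive scored 0.4 against the negative scored 0.6.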

The precision_recall_curve function calculates the data points of the PR curve. It also takes the true labels test_labels and the predicted scores y_hat as input, and returns precision, recall, and the corresponding thresholds. The average_precision_score function calculates the average precision (AP). Its value also ranges from 0 to 1, and the higher the AP, the better the model performance.
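The arrays returned by precision_recall_curve are laid out slightly differently from those of roc_curve, which matters when plotting or indexing them. As the sketch below illustrates on hypothetical toy data, the thresholds come in increasing order, recall is therefore non-increasing, and precision and recall each carry one extra final point (precision 1, recall 0) with no matching threshold. How many thresholds are returned can differ between scikit-learn versions, so the sketch does not rely on a specific count.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Hypothetical toy labels and scores
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
scores = np.array([0.9, 0.2, 0.4, 0.8, 0.6, 0.1, 0.7, 0.3])

precision, recall, thresholds = precision_recall_curve(y_true, scores)

# precision and recall are one element longer than thresholds
print(len(precision), len(recall), len(thresholds))

# The extra final point has precision 1 and recall 0
print("Last point:", precision[-1], recall[-1])

# Pair each threshold with the precision and recall it produces
for p, r, t in zip(precision, recall, thresholds):
    print(f"threshold >= {t}: precision = {p:.3f}, recall = {r:.2f}")
```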

6. Draw ROC curve and PR curve

import matplotlib.pyplot as plt

# Draw ROC curve
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # random-guess baseline
plt.xlim([0.0, 1.05])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()
print("AUC:", roc_auc)

# Draw PR curve
plt.figure()
plt.step(recall, precision, color='darkorange', lw=2, where='post', label=f'PR curve (AP = {pr_avg:.2f})')
# No-skill baseline: a horizontal line at the fraction of positives in the test set
positive_rate = np.mean(test_labels)
plt.plot([0, 1], [positive_rate, positive_rate], color='navy', lw=2, linestyle='--')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.05])
plt.title('Precision-Recall curve')
plt.legend(loc="lower left")
plt.show()
print("avg:", pr_avg)

Here, plt.figure() creates a new figure window.
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})') draws the ROC curve. fpr and tpr are the false positive rate and the true positive rate, i.e., the horizontal and vertical coordinates of the points on the ROC curve. The curve is drawn in orange ('darkorange') with a line width of 2, and its label shows the AUC value to two decimal places.
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') draws the diagonal from (0, 0) to (1, 1), which is the ROC curve of a classifier that guesses at random.
plt.xlim([0.0, 1.05]) and plt.ylim([0.0, 1.05]) set the ranges of the x-axis and y-axis to [0.0, 1.05].
plt.xlabel('False Positive Rate') and plt.ylabel('True Positive Rate') set the x-axis and y-axis labels.
plt.title('Receiver Operating Characteristic') sets the title of the figure.

[Figure: ROC curve]

[Figure: PR curve]

An AUC of 0.9028 indicates that the model has good classification ability: over a wide range of thresholds it achieves a high true positive rate at a low false positive rate.

An average precision of 0.9016 indicates that the model keeps precision high across most recall levels, i.e., most of the samples it ranks as positive really are positive.

7. Compare KNN with different k values

The loop below trains a KNN classifier for each n_neighbors value, computes its ROC and PR curves, and draws them on two shared panels for comparison.

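# Candidate k values (assumed for illustration; the original list was not given)
n_neighbors_values = [3, 5, 7, 9]

# Two panels: ROC curves on top, PR curves below
fig, (ax_roc, ax_pr) = plt.subplots(2, 1, figsize=(6, 10))

for n in n_neighbors_values:
    clf = KNeighborsClassifier(n_neighbors=n)
    clf.fit(train_features, train_labels)
    y_hat = clf.predict_proba(test_features)[:, 1]

    fpr, tpr, _ = roc_curve(test_labels, y_hat)
    roc_auc = auc(fpr, tpr)

    precision, recall, _ = precision_recall_curve(test_labels, y_hat)
    pr_ap = average_precision_score(test_labels, y_hat)

    ax_roc.plot(fpr, tpr, lw=2, label=f'n_neighbors={n} (AUC = {roc_auc:.2f})')
    ax_pr.step(recall, precision, lw=2, where='post', label=f'n_neighbors={n} (AP = {pr_ap:.2f})')

ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.legend(loc='lower right')
ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.legend(loc='lower left')
plt.tight_layout()
plt.show()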

[Figure: ROC curves (top) and PR curves (bottom) for each n_neighbors value]

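It can be seen that n_neighbors=7 gives the highest AUC and average precision among the values tried, indicating that with k = 7 the KNN classifier ranks the test samples most accurately on this data set. Because k was chosen by looking at test-set scores, these numbers are somewhat optimistic; selecting k by cross-validation on the training set would give a fairer estimate.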
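Choosing k by comparing test-set curves lets the test set influence the model. A common alternative, sketched below under the assumption of a hypothetical candidate grid, is to select k with 5-fold cross-validation on the training set only and touch the test set just once at the end. It uses scikit-learn's GridSearchCV with the built-in "roc_auc" scorer; "average_precision" is also a built-in scorer name and could be used instead.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Same data and split as the article
X, y = make_classification(n_samples=1000, n_features=10, random_state=10)
train_features, test_features, train_labels, test_labels = train_test_split(
    X, y, test_size=0.2, random_state=10)

# Hypothetical grid of k values, scored by ROC AUC with 5-fold CV on the training set
search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid={"n_neighbors": [3, 5, 7, 9, 11]},
    scoring="roc_auc",
    cv=5,
)
search.fit(train_features, train_labels)

best_k = search.best_params_["n_neighbors"]
print("Best k:", best_k)
print("Mean CV ROC AUC:", search.best_score_)

# Evaluate the refit best model once on the untouched test set
test_auc = search.score(test_features, test_labels)
print("Test ROC AUC:", test_auc)
```

The k picked this way may or may not match the one read off the test-set curves. Either way, the final test score is no longer biased by the selection.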

3. Summary

In this experiment, we used a KNN model to classify a synthetic data set. We divided the data into a training set and a test set, trained the KNN model on the training set, predicted scores for the test set, and evaluated its performance with metrics such as AUC and average precision.

The experiments show that the KNN model performs well on this data set. We then varied the value of k and found that AUC and average precision were largest at k = 7, so k = 7 performs best among the values tried. In short, this experiment deepened my understanding of the KNN model and taught me how to draw ROC and PR curves.