Two ways to draw PR curves (or compute evaluation metrics) for hashing-based image retrieval, and which to choose for your requirements

Table of Contents

  • Judging the best case across different hash code lengths
    • Average precision and recall over all queries at each Hamming radius
    • Average precision and recall over all queries at each ranking position in the retrieval set
    • Differences
      • Code differences
      • Differences in application scenarios
      • What the data points of the two methods mean
      • To plot a detailed PR curve at a fixed hash code length, choose the second method
      • Suitable scenarios for the first method

Judging the best case across different hash code lengths

Average precision and recall over all queries at each Hamming radius

import torch
import matplotlib.pyplot as plt


def calc_hamming_dist(B1, B2):
    # This helper is not defined in the original post; this version is an assumption.
    # It counts differing bits, so it works for {0, 1} codes.
    # B1: (num_bit,) single query code, B2: (num_retrieval, num_bit) -> (num_retrieval,)
    return (B1.unsqueeze(0) != B2).float().sum(dim=1)


def pr_curve(qB, rB, query_label, retrieval_label):
    """Accepts binary codes in {0, 1} or {-1, 1}; labels must be one-hot (or multi-hot) float tensors."""
    # Map {-1, 1} codes to {0, 1} without modifying the caller's tensors
    qB = (qB > 0).float()
    rB = (rB > 0).float()
    num_query = qB.shape[0]
    num_bit = qB.shape[1]
    P = torch.zeros(num_query, num_bit + 1)
    R = torch.zeros(num_query, num_bit + 1)
    for i in range(num_query):
        # (1, C) x (C, num_retrieval) -> (num_retrieval,): 1 if the retrieval item shares at least one label with query i
        gnd = (query_label[i].unsqueeze(0).mm(retrieval_label.t()) > 0).float().squeeze()
        print('gnd size is ' + str(gnd.shape))
        tsum = torch.sum(gnd)  # number of relevant retrieval items (ground-truth positives) for this query
        if tsum == 0:
            continue  # no relevant items: skip this query (its rows of P and R stay 0)
        hamm = calc_hamming_dist(qB[i, :], rB)  # Hamming distance to every retrieval item
        print('hamming distance is ' + str(hamm.shape))
        tmp = (hamm <= torch.arange(0, num_bit + 1).reshape(-1, 1).float().to(hamm.device)).float()
        print('tmp size is ' + str(tmp.shape))
        '''
        tmp has shape (num_bit + 1, num_retrieval). Row d marks the retrieval items whose Hamming
        distance to query i is at most d, i.e. the items returned at Hamming radius d, for d = 0..num_bit.
        '''
        total = tmp.sum(dim=-1)  # (num_bit + 1,): number of items retrieved at each radius
        print('total size is ' + str(total.shape))
        total = total + (total == 0).float() * 0.1  # avoid division by zero: empty radii get precision 0
        t = tmp * gnd  # (num_bit + 1, num_retrieval): retrieved AND relevant (true positives)
        print('t size is ' + str(t.shape))
        count = t.sum(dim=-1)  # (num_bit + 1,): number of true positives at each radius
        '''
        Meaning of the variables:
        1. count: number of relevant items retrieved within each radius (true positives)
        2. t: marks retrieved items that are also relevant, so it is a subset of gnd
        3. total: number of items retrieved within each radius, relevant or not
        4. tsum: total number of relevant items for this query (all ground-truth positives)
        '''
        p = count / total  # precision = true positives / items retrieved
        print('p shape is ' + str(p.shape))

        r = count / tsum  # recall = true positives / all relevant items
        print('r shape is ' + str(r.shape))
        P[i] = p  # precision of query i at every radius
        R[i] = r  # recall of query i at every radius
    print(f'P size is {P.shape}')
    print(f'R size is {R.shape}')
    mask = (P > 0).float().sum(dim=0)  # per radius: number of queries with non-zero precision
    mask = mask + (mask == 0).float() * 0.1
    P = P.sum(dim=0) / mask
    R = R.sum(dim=0) / mask
    '''
    The Hamming radius d ranges from 0 to num_bit.
    For each query, precision (P) and recall (R) are computed at every radius;
    the mask then averages each radius only over queries with non-zero precision there.
    '''
    # Visualization
    plt.plot(R.numpy(), P.numpy(), linestyle="-", marker='D', color='blue', label='DSH')
    plt.text(0.5, -0.1, '(a) PR curve @ 16bits', ha='center', va='center', fontsize=16, fontweight='bold', transform=plt.gca().transAxes)
    plt.grid(True)
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # Resize the figure
    fig = plt.gcf()
    fig.set_size_inches(9, 9)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend()  # add legend
    plt.show()

    return P, R

The function above computes the average precision and recall over all queries at each Hamming radius d = 0, 1, ..., num_bit, so a b-bit code yields b + 1 PR points.
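For readers without PyTorch, the same per-radius computation can be sketched in NumPy. This is a minimal sketch, not the author's code: the name pr_by_hamming_radius is hypothetical, codes are assumed to be {0, 1} arrays, labels are multi-hot arrays, and the final averaging mimics the mask in pr_curve.

```python
import numpy as np


def pr_by_hamming_radius(q_codes, r_codes, q_labels, r_labels):
    """Mean precision/recall at Hamming radius d = 0..num_bit (one PR point per radius).

    q_codes: (num_query, num_bit), r_codes: (num_retrieval, num_bit), values in {0, 1}.
    q_labels: (num_query, num_classes), r_labels: (num_retrieval, num_classes), multi-hot.
    """
    num_query, num_bit = q_codes.shape
    P = np.zeros((num_query, num_bit + 1))
    R = np.zeros((num_query, num_bit + 1))
    radii = np.arange(num_bit + 1).reshape(-1, 1)  # (num_bit + 1, 1)
    for i in range(num_query):
        gnd = (r_labels @ q_labels[i] > 0).astype(float)  # relevant if any label is shared
        tsum = gnd.sum()
        if tsum == 0:
            continue  # no relevant items: leave this query's rows at 0
        hamm = np.count_nonzero(q_codes[i] != r_codes, axis=1)  # (num_retrieval,)
        within = (hamm <= radii).astype(float)  # (num_bit + 1, num_retrieval)
        total = within.sum(axis=1)  # items retrieved at each radius
        count = (within * gnd).sum(axis=1)  # relevant items among them
        P[i] = count / np.maximum(total, 1)  # empty radius -> precision 0
        R[i] = count / tsum
    # Average each radius only over queries with non-zero precision (the mask in pr_curve)
    mask = np.maximum((P > 0).sum(axis=0), 1)
    return P.sum(axis=0) / mask, R.sum(axis=0) / mask


# Toy example: one 4-bit query, three retrieval items; items 0 and 2 are relevant
q_codes = np.array([[0, 0, 0, 0]])
r_codes = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [1, 1, 1, 1]])
q_labels = np.array([[1, 0]])
r_labels = np.array([[1, 0], [0, 1], [1, 0]])
P, R = pr_by_hamming_radius(q_codes, r_codes, q_labels, r_labels)
print(P)  # 5 points, one per radius 0..4
print(R)
```

With Hamming distances 0, 1 and 4 in the toy example, radius 0 returns only the identical code (precision 1, recall 0.5) and radius 4 returns everything (recall 1).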

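Average precision and recall over all queries at each ranking position in the retrieval set

Because it counts mismatching positions directly, this method works with any binary encoding ({0, 1} or {-1, 1}) without conversion, as long as the query and retrieval codes use the same one. It expects single integer labels rather than one-hot labels.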

import numpy as np
import matplotlib.pyplot as plt


def plot_pr_curve(query_binary, retrieval_binary, query_label, retrieval_label):
    """
    Draw the precision-recall (PR) curve for retrieval evaluation.

    Args:
        query_binary (numpy.ndarray): array of shape (num_query, num_bit)
            containing the binary hash codes of the query images.
        retrieval_binary (numpy.ndarray): array of shape (num_retrieval, num_bit)
            containing the binary hash codes of the retrieval images.
        query_label (numpy.ndarray): array of shape (num_query,)
            containing the integer ground-truth labels of the query images.
        retrieval_label (numpy.ndarray): array of shape (num_retrieval,)
            containing the integer ground-truth labels of the retrieval images.

    Returns:
        tuple: (mean_precision, mean_recall), each of shape (num_retrieval,)
    """
    # Make sure the labels are integers before comparing them
    query_label = query_label.astype(np.int32)
    retrieval_label = retrieval_label.astype(np.int32)

    # Hamming distance between every query code and every retrieval code
    hamming_dist = np.count_nonzero(query_binary[:, np.newaxis, :]
                                    != retrieval_binary[np.newaxis, :, :], axis=2)
    '''
    query_binary and retrieval_binary have shapes (m, n) and (p, n), where n is the code length.
    query_binary[:, np.newaxis, :] reshapes the query codes to (m, 1, n) by inserting a new axis at position 1.
    retrieval_binary[np.newaxis, :, :] reshapes the retrieval codes to (1, p, n) by inserting a new axis at position 0.
    Broadcasting != compares them element-wise into an (m, p, n) boolean array (memory grows with m * p * n).
    np.count_nonzero(..., axis=2) counts the unequal bits along the last axis, giving the (m, p) Hamming distance matrix.
    '''
    print(hamming_dist)
    # Sort retrieval items by ascending Hamming distance (closest first; ties keep database order)
    idx = np.argsort(hamming_dist, axis=1, kind='stable')

    # Initialize precision and recall arrays
    num_query = query_binary.shape[0]
    num_retrieval = retrieval_binary.shape[0]
    precision = np.zeros((num_query, num_retrieval))
    recall = np.zeros((num_query, num_retrieval))

    # Compute precision and recall at every ranking cutoff for each query
    for i in range(num_query):
        # Relevance of each retrieval item, in ranked order
        gnd = (query_label[i] == retrieval_label[idx[i]])
        num_relevant = np.count_nonzero(gnd)
        # Cumulative true positives and false positives down the ranking
        tp_cumsum = np.cumsum(gnd)
        fp_cumsum = np.cumsum(~gnd)  # ~ negates the boolean array
        precision[i] = tp_cumsum / (tp_cumsum + fp_cumsum)  # fraction of the top k that is relevant
        # Fraction of all relevant items found in the top k (0 if the query has none, instead of NaN)
        recall[i] = tp_cumsum / num_relevant if num_relevant > 0 else 0.0
    # Average precision and recall over all queries at each cutoff
    mean_precision = np.mean(precision, axis=0)
    mean_recall = np.mean(recall, axis=0)

    # Plot the precision-recall curve
    plt.plot(mean_recall, mean_precision, 'b-')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.show()

    return mean_precision, mean_recall
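The core of the rank-based method can be separated from the plotting, which makes it easy to test. This is a sketch under stated assumptions: pr_by_rank is a hypothetical name, labels are single integers as in plot_pr_curve, ties are broken by database order (kind="stable"), and queries with no relevant item are skipped, which the original does not do. Since tp_cumsum + fp_cumsum is simply k at cutoff k, the sketch divides by k directly.

```python
import numpy as np


def pr_by_rank(q_codes, r_codes, q_labels, r_labels):
    """Mean precision/recall at every ranking cutoff k = 1..num_retrieval.

    Codes: any binary encoding ({0, 1} or {-1, 1}); labels: one integer per sample.
    Queries with no relevant item are skipped instead of producing NaN recall.
    """
    hamm = np.count_nonzero(q_codes[:, None, :] != r_codes[None, :, :], axis=2)  # (m, p)
    order = np.argsort(hamm, axis=1, kind="stable")  # closest first; ties keep database order
    precisions, recalls = [], []
    for i in range(q_codes.shape[0]):
        rel = q_labels[i] == r_labels[order[i]]  # relevance in ranked order
        if not rel.any():
            continue
        tp = np.cumsum(rel)  # true positives within the top k
        k = np.arange(1, len(rel) + 1)  # cutoff sizes 1..num_retrieval
        precisions.append(tp / k)
        recalls.append(tp / rel.sum())
    return np.mean(precisions, axis=0), np.mean(recalls, axis=0)


# Toy example: same codes as before, with integer labels (items 0 and 2 share the query's label)
P_rank, R_rank = pr_by_rank(
    np.array([[0, 0, 0, 0]]),
    np.array([[0, 0, 0, 0], [0, 0, 0, 1], [1, 1, 1, 1]]),
    np.array([0]),
    np.array([0, 1, 0]),
)
print(P_rank, R_rank)  # one point per retrieval item: 3 points
```

If every query lacks relevant items, the lists stay empty and np.mean returns NaN with a warning, so real evaluation code may want to check for that case.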

Differences

Code differences

The two pieces of code compute the PR curve differently in the following main ways:

The Hamming distance is computed differently:
The first piece of code calls a helper, calc_hamming_dist(qB[i, :], rB), one query at a time.

The second piece of code inserts new axes into query_binary and retrieval_binary, broadcasts them to shape (num_query, num_retrieval, num_bit), and counts the positions where the bits differ, computing all distances at once.
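Both routes give the same distances. A common way to implement a helper like calc_hamming_dist for {-1, +1} codes (an assumption; the post does not show its body) uses the identity d = (n - b1·b2) / 2: agreeing bits add +1 and disagreeing bits add -1 to the inner product, so b1·b2 = n - 2d. The sketch below checks this against the mismatch count on random codes. Note that the identity only holds for {-1, +1} codes, not for {0, 1} codes.

```python
import numpy as np

rng = np.random.default_rng(0)
num_bit = 16
q01 = rng.integers(0, 2, size=(3, num_bit))  # three random {0, 1} query codes
r01 = rng.integers(0, 2, size=(5, num_bit))  # five random {0, 1} retrieval codes

# Second piece of code: broadcast to (3, 5, num_bit) and count unequal bits
d_count = np.count_nonzero(q01[:, None, :] != r01[None, :, :], axis=2)

# Inner-product form, valid only after mapping the codes to {-1, +1}
qpm, rpm = 2 * q01 - 1, 2 * r01 - 1
d_inner = 0.5 * (num_bit - qpm @ rpm.T)

print(np.array_equal(d_count, d_inner))
```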

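Retrieval samples are handled differently:
The first piece of code does not sort at all. It compares each Hamming distance with every radius threshold d = 0..num_bit and counts the items that fall within each radius.

The second piece of code explicitly sorts the retrieval samples by ascending Hamming distance and evaluates every prefix (top k) of that ranking.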

Relevance (positive versus negative samples) is determined differently:
The first piece of code takes the inner product of one-hot (or multi-hot) label vectors: a retrieval sample counts as relevant if it shares at least one label with the query, and the result for each query is stored in gnd.

The second piece of code checks whether the integer query label equals the retrieval label, so it supports only single-label data.
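For single-label data the two relevance rules agree. A short sketch with made-up labels: integer equality gives the same relevance matrix as the inner product of one-hot vectors, while multi-hot labels (supported only by the first piece of code) also count partial overlaps as relevant.

```python
import numpy as np

# Single integer labels (second piece of code)
q_int = np.array([2, 0])
r_int = np.array([0, 2, 1, 2])
match_int = q_int[:, None] == r_int[None, :]  # (2, 4) boolean relevance

# One-hot labels and inner product (first piece of code)
num_classes = 3
q_onehot = np.eye(num_classes)[q_int]
r_onehot = np.eye(num_classes)[r_int]
match_inner = (q_onehot @ r_onehot.T) > 0  # relevant if any label is shared

print(np.array_equal(match_int, match_inner))

# A multi-hot query tagged with classes 0 and 1 matches items 0 and 2
q_multi = np.array([[1, 1, 0]])
match_multi = (q_multi @ r_onehot.T) > 0
print(match_multi)
```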

Precision and recall are calculated differently:
The first piece of code counts, for each radius, the retrieved samples (total) and the relevant retrieved samples (count), then sets precision = count / total and recall = count / tsum.

The second piece of code accumulates true positives and false positives with np.cumsum along the ranking, giving precision and recall at every cutoff k.

The averaging over queries is also different:
The first piece of code uses the mask to average each radius only over queries with non-zero precision there (queries without any relevant sample are skipped entirely). The second piece of code takes a plain np.mean over all queries.
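The difference matters whenever some queries score zero at a point: the mask excludes them from the average, which raises the first method's curve relative to a plain mean. A toy illustration with made-up numbers (not results from either piece of code):

```python
import numpy as np

# Per-query precision at three points for three queries (made-up numbers);
# query 2 has precision 0 at the first point
P = np.array([[1.0, 0.8, 0.6],
              [0.5, 0.5, 0.4],
              [0.0, 0.2, 0.3]])

plain_mean = P.mean(axis=0)  # second piece of code: divide by the number of queries
mask = np.maximum((P > 0).sum(axis=0), 1)  # first piece of code: count non-zero entries
masked_mean = P.sum(axis=0) / mask

print(plain_mean)   # first column: (1.0 + 0.5 + 0.0) / 3 = 0.5
print(masked_mean)  # first column: (1.0 + 0.5) / 2 = 0.75
```

Where every query has non-zero precision (the second and third columns here), the two averages coincide.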
To sum up, the calculation processes of the two pieces of code differ substantially, mainly in:

Calculation of the Hamming distance
Sorting (or thresholding) of retrieval samples
How queries are matched with relevant retrieval samples
How precision and recall are computed at each point
Therefore, the two pieces of code compute the PR curve with quite different algorithms, but the ultimate goal is the same: average precision and recall over all queries.

Differences in application scenarios

First piece of code:

It computes the average precision and recall over all query images at each Hamming radius.
The number of data points generated is num_bit + 1 (radius 0 to num_bit).
Applicable scenario: evaluating overall performance; running it once per hash code length lets you compare code lengths.
Second piece of code:

It computes the precision and recall of each query image at every ranking cutoff k.
The number of data points generated is the number of samples in the retrieval set.
It then averages over queries to get the mean precision and recall at each cutoff.
Applicable scenario: drawing a dense PR curve and analyzing how individual query images perform.
In practice:

If you only need an averaged PR curve sampled at each Hamming radius, the first piece of code is enough.

If you want to analyze how individual query images perform, the second piece of code is more useful: its per-query precision and recall arrays hold one curve per query, so by running it for each hash code length you can check whether different image categories favor different code lengths.

In general:

The first piece of code calculates the average PR curve of all query images over Hamming radii, which suits evaluating the overall effect of different hash code lengths.

The second piece of code calculates the PR curve of each query image over ranking cutoffs and then averages them, which suits analyzing how individual queries respond to the hash code length.

What the data points of the two methods mean

The PR curves drawn by the two pieces of code are built from different kinds of data points.

Specifically:

The first piece of code calculates the average precision and recall of all query images at each Hamming radius.
This means:

For each Hamming radius d, it generates one data point (P, R): the average precision and recall of all query images when every retrieval sample within distance d is returned.

In total it generates num_bit + 1 data points, one for each radius from 0 to num_bit.

So these num_bit + 1 data points reflect how overall model performance changes as the Hamming radius grows.

The second piece of code calculates:

The precision and recall of each query image at every ranking cutoff.
This means:

It generates one data point (P, R) for each query image and each cutoff k, where the top k retrieval samples (sorted by Hamming distance) are returned.

Each query image therefore produces num_retrieval data points.

It then averages over queries to get the mean precision and recall at each cutoff.

So these num_retrieval data points reflect how performance changes as more and more of the ranking is returned, averaged over the query images.

So the data points generated by the two methods mean different things:

The first method focuses on the overall trend of model performance as the Hamming radius changes.
The second method focuses on each query image's ranking and obtains the overall trend by averaging.
In general, my understanding is that the PR curves drawn by the two pieces of code correspond to different data points, so they reflect different information and have different emphases.

To plot a detailed PR curve at a fixed hash code length, choose the second method

Suppose you need to draw PR curves for a few fixed hash code lengths (16, 32, 64, 128 bits).

The second method calculates the precision and recall of each query image at every ranking cutoff and takes the average. This describes the performance at a fixed hash code length more fully.

Each PR point of the second method corresponds to a position in the ranked retrieval set, so the curve has as many points as there are retrieval samples. For a fixed hash code length it makes full use of the data and describes the PR curve more precisely.

Each PR point of the first method corresponds to a Hamming radius. For a fixed b-bit code it generates only b + 1 PR points (17 for 16 bits), so it carries relatively little information.
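For a single query the two methods are closely related: once the retrieval samples are sorted by Hamming distance, the items within radius d are exactly the top k items, with k = number of items at distance at most d. So each radius point of the first method is one of the rank points of the second. The sketch below checks this on random codes with made-up relevance and shows the point counts (num_bit + 1 versus num_retrieval). After averaging over queries the curves are no longer subsets of each other, because each query reaches a given radius at a different k and the two methods average differently.

```python
import numpy as np

rng = np.random.default_rng(1)
num_bit, num_retrieval = 8, 50
q = rng.integers(0, 2, size=num_bit)  # one random {0, 1} query code
r = rng.integers(0, 2, size=(num_retrieval, num_bit))  # random retrieval codes
rel = rng.random(num_retrieval) < 0.3  # made-up relevance
rel[0] = True  # ensure at least one relevant item

hamm = np.count_nonzero(q != r, axis=1)

# Rank-based curve (second method): one point per cutoff k = 1..num_retrieval
order = np.argsort(hamm, kind="stable")
tp = np.cumsum(rel[order])
k = np.arange(1, num_retrieval + 1)
prec_rank, rec_rank = tp / k, tp / rel.sum()

# Radius-based points (first method): one per radius d = 0..num_bit
for d in range(num_bit + 1):
    within = hamm <= d
    n_within = np.count_nonzero(within)
    if n_within == 0:
        continue  # nothing retrieved at this radius
    hits = np.count_nonzero(rel & within)
    # The radius-d point equals the rank-based point at k = n_within
    assert np.isclose(hits / n_within, prec_rank[n_within - 1])
    assert np.isclose(hits / rel.sum(), rec_rank[n_within - 1])

print(num_bit + 1, "radius points vs", num_retrieval, "rank points")
```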

So in summary:

If you want to draw PR curves at a few fixed hash code lengths, the second method (computing the PR curve of each query image over ranking cutoffs) is more appropriate,
because it uses every sample in the retrieval set to describe the PR curve at a fixed hash code length more precisely.
The first method generates only num_bit + 1 PR points, so it carries relatively little information.
I hope this analysis helps you choose between the two pieces of code.

  • Summary of main points:

Your requirement is to draw PR curves at a few fixed hash code lengths.
The second method computes the PR curve of each query image over ranking cutoffs and takes the average, which suits a fixed hash code length better.
Because it uses every sample in the retrieval set, it carries more information and describes the PR curve more precisely.
The first method, by comparison, generates only num_bit + 1 PR points, so it carries relatively little information.

Suitable scenarios for the first method

There are two main scenarios where the first piece of code fits:

  • Evaluate the overall effect of different hash code lengths
    By running it once per hash code length and calculating the average precision and recall of all query samples, we can evaluate how the model's overall performance changes as the hash code length increases.

  • Compare different models across hash code lengths
    Keeping everything else constant (the same dataset, query set, and retrieval set), you can compare the average PR curves of different models (such as DSH and ITQ) at each hash code length to evaluate their overall effect.

  • The characteristics of the first piece of code are:

It calculates the average PR values of all query samples at each Hamming radius
The number of data points generated is num_bit + 1
It reflects how overall model performance changes as the Hamming radius changes
Therefore, it mainly suits scenarios where you need to evaluate how the overall effect of the model changes.

  • The characteristics of the second piece of code are:

It calculates the PR values of each query sample at every ranking cutoff, then takes the average
The number of data points is the number of retrieval samples
It reflects how each query sample's ranking performs
It is more suitable for analyzing differences between query samples, and for comparing them across hash code lengths by running it once per length.