In the End-to-End Object Detection with Transformers (DETR) paper, the authors directly predict a fixed number \(N\) of boxes and treat them as a set. To match the predicted boxes with the target boxes, they use the Hungarian matching algorithm. There is a great blog post by Lei Mao explaining the basic concepts of the Hungarian matching algorithm.

Short summary of the problem

In the case of DETR, we predict 100 boxes, which is more than the number of objects in almost any image. Our task is to find the closest predicted box for each target box, where closeness is measured by a matching cost and each prediction is used at most once. In other words, we select the \(n\) best predictions among the \(m\) outputs.

To do this, we form a cost matrix \(C\) of size \(m \times n\), where \(m\) is the number of predictions, \(n\) is the number of targets, and \(m > n\). \(C_{ij}\) is the matching cost between prediction \(i\) and ground-truth box \(j\).

Matching cost

The matching cost \(C_{ij}\) between prediction \(i\) and target \(j\) is given by:

\[\begin{equation} C_{ij} = \mathcal{L}_{iou}(\hat{b}_i, b_j) + ||\hat{b}_i - b_j||_1 - \hat{p}_i(c_j) \end{equation}\]

where \(\hat{b}_i\) is the predicted box, \(b_j\) is the target box, and \(\hat{p}_i(c_j)\) is the probability that prediction \(i\) assigns to the class \(c_j\) of target \(j\). After calculating the cost matrix \(C\), we can use the linear_sum_assignment function from the SciPy package. It returns row_ids and column_ids, which correspond to (matched_prediction_ids, target_ids).
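As a small illustration of the assignment step on its own, here is a minimal sketch with a hand-written cost matrix. The values are made up for the example; only linear_sum_assignment comes from SciPy.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Hypothetical 4x2 cost matrix: 4 predictions (rows), 2 targets (columns).
cost = np.array([
    [0.9, 0.8],
    [0.2, 0.7],
    [0.6, 0.1],
    [0.5, 0.4],
])

# Each target (column) gets exactly one prediction (row),
# chosen so that the sum of the selected costs is minimal.
row_ids, col_ids = linear_sum_assignment(cost)
total_cost = cost[row_ids, col_ids].sum()

print(row_ids, col_ids)  # [1 2] [0 1]
```

Here prediction 1 is matched with target 0 and prediction 2 with target 1, for a total cost of 0.2 + 0.1. The remaining two predictions stay unmatched, which is what happens to most of the predictions in DETR.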


The following is a simplified version of the Hungarian matching used in the DETR source code.

import torch
from scipy.optimize import linear_sum_assignment

target_labels = [1,2] # two target labels
target_bboxes = torch.rand(2,4) # two target boxes

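pred_logits = torch.rand(10, 3) # 10 predictions for 3 labels
pred_probs = pred_logits.softmax(-1) # convert logits to class probabilities
pred_bboxes = torch.rand(10, 4)  # 4 box coordinates for each of 10 predictions

# Negative probability of each target class under each prediction
class_cost = -pred_probs[:, target_labels] # 10x2
# torch.cdist with p=1 returns the pairwise L1 distance matrix, 10x2
l1_cost = torch.cdist(pred_bboxes, target_bboxes, p=1) # 10x2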
# To simplify, we omit the IoU calculation and use random values as a placeholder
iou_cost = torch.randn(10,2) # 10x2

cost_matrix = class_cost + l1_cost + iou_cost
match_preds, match_targets = linear_sum_assignment(cost_matrix)
print(match_preds, match_targets)
# Example output (varies between runs because the inputs are random): [2 9] [0 1]

In this example, predictions \(2\) and \(9\) are matched with target boxes \(0\) and \(1\), respectively: (2<->0), (9<->1).
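The example above uses random values in place of the IoU cost. As a rough sketch of how that term could be filled in, the function below computes a pairwise \(1 - \text{IoU}\) matrix. It rests on a few assumptions that are not from the original code: the name pairwise_iou_cost is hypothetical, boxes are expected in (x1, y1, x2, y2) format with x2 > x1 and y2 > y1, and plain IoU is used. The DETR source works with a generalized IoU variant, which this sketch does not reproduce.

```python
import torch

def pairwise_iou_cost(pred_boxes: torch.Tensor, target_boxes: torch.Tensor) -> torch.Tensor:
    """Return an (m, n) matrix of 1 - IoU for boxes in (x1, y1, x2, y2) format."""
    # Box areas, shapes (m,) and (n,)
    pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
    target_area = (target_boxes[:, 2] - target_boxes[:, 0]) * (target_boxes[:, 3] - target_boxes[:, 1])

    # Corners of the intersection rectangle for every (prediction, target) pair
    top_left = torch.max(pred_boxes[:, None, :2], target_boxes[None, :, :2])      # (m, n, 2)
    bottom_right = torch.min(pred_boxes[:, None, 2:], target_boxes[None, :, 2:])  # (m, n, 2)

    # Clamp at zero so non-overlapping pairs get zero intersection
    wh = (bottom_right - top_left).clamp(min=0)
    intersection = wh[..., 0] * wh[..., 1]  # (m, n)

    union = pred_area[:, None] + target_area[None, :] - intersection
    return 1.0 - intersection / union

# Three hand-picked predictions against one target box
preds = torch.tensor([[0.0, 0.0, 2.0, 2.0],   # identical to the target
                      [1.0, 1.0, 3.0, 3.0],   # partial overlap
                      [5.0, 5.0, 6.0, 6.0]])  # no overlap
targets = torch.tensor([[0.0, 0.0, 2.0, 2.0]])

iou_cost = pairwise_iou_cost(preds, targets)  # shape (3, 1)
print(iou_cost)
```

The identical box gets cost 0, the partially overlapping box gets 6/7 (intersection 1, union 7), and the disjoint box gets cost 1. Note that boxes drawn with torch.rand(…, 4) are generally not valid corner-format boxes, so they would need to be constructed differently before being passed to a function like this.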