在训练的时候由于刚开始进行验证的时候会在同一张图片中生成较多矩形检测框,因此要计算的矩阵会变得非常大,非常容易把显存撑爆,这里提供几种解决思路:

  1. 把图片大小改小:在训练时指定imgsz=512或者更小,但是缺点就是精度可能降低以及后续的训练可能会继续出现显存爆炸的问题。

  2. (推荐)修改计算矩阵的函数,把较大的矩阵分片计算:

# # 原始函数
# def batch_probiou(obb1, obb2, eps=1e-7):
#     """
#     Calculate the probabilistic IoU between oriented bounding boxes.

#     Args:
#         obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
#         obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
#         eps (float, optional): A small value to avoid division by zero.

#     Returns:
#         (torch.Tensor): A tensor of shape (N, M) representing obb similarities.

#     References:
#         https://arxiv.org/pdf/2106.06072v1.pdf
#     """
#     obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
#     obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2

#     x1, y1 = obb1[..., :2].split(1, dim=-1)
#     x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
#     a1, b1, c1 = _get_covariance_matrix(obb1)
#     a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))

#     t1 = (
#         ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
#     ) * 0.25
#     t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
#     t3 = (
#         ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
#         / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
#         + eps
#     ).log() * 0.5
#     bd = (t1 + t2 + t3).clamp(eps, 100.0)
#     hd = (1.0 - (-bd).exp() + eps).sqrt()
#     return 1 - hd
def batch_probiou(obb1, obb2, eps=1e-7, chunk_size=10000):
    """
    Calculate the probabilistic IoU between oriented bounding boxes with chunking.

    Args:
        obb1 (torch.Tensor | np.ndarray): Ground truth obbs (N, 5) in xywhr format.
        obb2 (torch.Tensor | np.ndarray): Predicted obbs (M, 5) in xywhr format.
        eps (float): Small value to avoid division by zero.
        chunk_size (int): Maximum chunk size for dimension to prevent memory overflow.

    Returns:
        (torch.Tensor): IoU matrix of shape (N, M).
    """
    obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
    obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2

    N, M = obb1.shape[0], obb2.shape[0]

    # Handle empty inputs
    if N == 0 or M == 0:
        return torch.zeros((N, M), device=obb1.device)

    # Check if chunking is needed
    if N <= chunk_size and M <= chunk_size:
        # Original calculation for small matrices
        x1, y1 = obb1[..., :2].split(1, dim=-1)
        x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
        a1, b1, c1 = _get_covariance_matrix(obb1)
        a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))

        t1 = ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / (
            (a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps) * 0.25
        t2 = ((c1 + c2) * (x2 - x1) * (y1 - y2)) / (
            (a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps) * 0.5
        t3 = (((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)) / 
             (4 * ((a1 * b1 - c1.pow(2)).clamp(0) * (a2 * b2 - c2.pow(2)).clamp(0)).sqrt() + eps) + eps).log() * 0.5

        bd = (t1 + t2 + t3).clamp(eps, 100.0)
        hd = (1.0 - (-bd).exp() + eps).sqrt()
        return 1 - hd
    else:
        # Chunked calculation for large matrices
        device = obb1.device
        result = torch.zeros((N, M), device=device)

        # Process chunks of obb1
        for i in range(0, N, chunk_size):
            i_end = min(i + chunk_size, N)
            obb1_chunk = obb1[i:i_end]

            # Process chunks of obb2
            for j in range(0, M, chunk_size):
                j_end = min(j + chunk_size, M)
                obb2_chunk = obb2[j:j_end]

                # Recursive call for smaller chunk
                chunk_result = batch_probiou(obb1_chunk, obb2_chunk, eps, chunk_size)
                result[i:i_end, j:j_end] = chunk_result.to(device)

        return result

上面的修改可以把比较大的矩阵分成10000x10000的片来分开计算,然后合并结果,这样就不会爆显存了,但是计算速度会变慢,经过实际测试发现没事,只有前面的几十次epoches会这样,后面精度上来之后就不会有比较多的预测框了,也就不会有比较大的矩阵了,速度会快起来,10000x10000可以按照自己的显卡来调节,更改代码第一行chunk_size=10000就行。

  1. 直接换显卡!哈哈哈哈