torchsig.models.spectrogram_models.detr.criterion

Criterion and matching modules from Detectron2, Mask2Former, and DETR codebases

Functions

batch_dice_loss

Compute the DICE loss, similar to generalized IoU for masks. :param inputs: A float tensor of arbitrary shape. The predictions for each example. :param targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class).

batch_sigmoid_ce_loss

calculate_uncertainty

We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
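The L1-distance idea above can be illustrated in plain Python. This is a minimal sketch, not the torchsig implementation: the helper name `uncertainty_sketch` is hypothetical, and negating the distance so that the most uncertain element gets the highest score is an assumption borrowed from common point-sampling code.

```python
def uncertainty_sketch(logits):
    """Score each logit by its negated L1 distance from 0.0.

    A logit near 0.0 corresponds to a probability near 0.5, i.e. an
    uncertain prediction. The negation (an assumption here) makes the
    most uncertain element the one with the highest score.
    """
    return [-abs(v) for v in logits]


logits = [4.0, -0.1, -3.0, 0.5]
scores = uncertainty_sketch(logits)

# Index of the most uncertain prediction (highest score).
most_uncertain = max(range(len(logits)), key=lambda i: scores[i])
print(scores, most_uncertain)
```

Ranking by this score is one way to decide where extra sample points are most useful, since confident predictions (large positive or negative logits) sink to the bottom.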

dice_loss

Compute the DICE loss, similar to generalized IoU for masks. :param inputs: A float tensor of arbitrary shape. The predictions for each example. :param targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class).
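As a worked illustration of the Dice formula, the sketch below computes the loss for a single flattened mask in plain Python. It is not the torchsig code: the name `dice_loss_sketch` is hypothetical, and two details are assumptions, namely that the predictions are logits passed through a sigmoid and that a smoothing term `eps = 1.0` is added to numerator and denominator.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def dice_loss_sketch(logits, targets, eps=1.0):
    """Dice loss for one flattened mask.

    loss = 1 - (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps)
    where p = sigmoid(logit). Both the sigmoid and eps are assumptions.
    """
    probs = [sigmoid(v) for v in logits]
    intersection = sum(p * t for p, t in zip(probs, targets))
    denominator = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


targets = [1, 1, 0, 0]
good = dice_loss_sketch([10.0, 10.0, -10.0, -10.0], targets)  # mask agrees
bad = dice_loss_sketch([-10.0, -10.0, 10.0, 10.0], targets)   # mask inverted
print(good, bad)
```

A prediction that overlaps the target well yields a loss near 0, while a prediction with no overlap yields a loss close to 1, which is why Dice behaves like an overlap measure for masks.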

get_uncertain_point_coords_with_randomness

Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties

get_world_size

is_dist_avail_and_initialized

nested_tensor_from_tensor_list

point_sample

A wrapper around :func:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. Unlike :func:`torch.nn.functional.grid_sample`, it assumes point_coords to lie inside the [0, 1] x [0, 1] square. :param input: A tensor of shape (N, C, H, W) that contains a feature map on an H x W grid. :type input: Tensor :param point_coords: A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains [0, 1] x [0, 1] normalized point coordinates. :type point_coords: Tensor
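The coordinate convention can be sketched without PyTorch. `torch.nn.functional.grid_sample` expects coordinates in [-1, 1], so a wrapper taking [0, 1] points has to rescale them; the plain-Python sketch below shows that rescaling and a bilinear lookup on a single 2D feature map. The helper names are hypothetical, and the pixel-center convention (centers at `(i + 0.5) / size`) and zero padding outside the grid are assumptions chosen to mirror `grid_sample`'s defaults.

```python
import math


def to_grid_sample_coords(point):
    """Rescale an (x, y) point from [0, 1] x [0, 1] to [-1, 1] x [-1, 1]."""
    x, y = point
    return (2.0 * x - 1.0, 2.0 * y - 1.0)


def bilinear_point_sample(feature, point):
    """Bilinearly sample a 2D list `feature` (H rows, W columns) at `point`.

    `point` is (x, y) in [0, 1] x [0, 1]; x runs along the width.
    Assumes pixel centers at (i + 0.5) / size and zeros outside the grid.
    """
    h, w = len(feature), len(feature[0])
    px = point[0] * w - 0.5  # continuous column index
    py = point[1] * h - 0.5  # continuous row index
    x0, y0 = math.floor(px), math.floor(py)
    fx, fy = px - x0, py - y0

    def at(r, c):
        # Zero padding for neighbours that fall outside the feature map.
        return feature[r][c] if 0 <= r < h and 0 <= c < w else 0.0

    top = (1 - fx) * at(y0, x0) + fx * at(y0, x0 + 1)
    bottom = (1 - fx) * at(y0 + 1, x0) + fx * at(y0 + 1, x0 + 1)
    return (1 - fy) * top + fy * bottom


feature = [[0.0, 1.0],
           [2.0, 3.0]]
center = bilinear_point_sample(feature, (0.5, 0.5))
top_left_pixel = bilinear_point_sample(feature, (0.25, 0.25))
print(center, top_left_pixel)
```

Sampling at the middle of the map averages all four pixels, and sampling at a pixel center returns that pixel's value unchanged.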

sigmoid_ce_loss

Classes

HungarianMatcher

This class computes an assignment between the targets and the predictions of the network
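To make the idea of an assignment concrete, the sketch below finds the minimum-cost one-to-one matching between predictions and targets from a cost matrix. It deliberately swaps in a brute-force search over permutations for the Hungarian algorithm, which is only practical for tiny inputs; DETR-style code typically calls `scipy.optimize.linear_sum_assignment` instead. The function name `match_brute_force` and the cost values are hypothetical.

```python
from itertools import permutations


def match_brute_force(cost):
    """Return (pairs, total) for the cheapest one-to-one assignment.

    cost[i][j] is the cost of assigning prediction i to target j.
    Assumes there are at least as many predictions as targets.
    Each pair is (prediction_index, target_index).
    """
    num_preds, num_targets = len(cost), len(cost[0])
    best_total, best_pairs = float("inf"), []
    # Try every way of picking a distinct prediction for each target.
    for preds in permutations(range(num_preds), num_targets):
        total = sum(cost[p][t] for t, p in enumerate(preds))
        if total < best_total:
            best_total = total
            best_pairs = sorted((p, t) for t, p in enumerate(preds))
    return best_pairs, best_total


# Three predictions, two targets (illustrative costs).
cost = [
    [0.9, 0.2],
    [0.1, 0.8],
    [0.5, 0.6],
]
pairs, total = match_brute_force(cost)
print(pairs, total)
```

Here prediction 1 is matched to target 0 and prediction 0 to target 1, while prediction 2 stays unmatched. In a real matcher the cost matrix would combine classification and localization terms.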

NestedTensor

SetCriterion

This class computes the loss for DETR. The process happens in two steps: 1) compute the Hungarian assignment between the ground-truth boxes and the outputs of the model; 2) supervise each matched ground-truth / prediction pair (class and box).
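The second step can be sketched in plain Python once an assignment is available. This is an illustration under stated assumptions, not the SetCriterion code: the function names are hypothetical, the class term is a softmax cross-entropy, the box term is a plain L1 distance, and both are averaged over the matched pairs. Supervision of unmatched predictions is left out.

```python
import math


def cross_entropy(logits, target_index):
    """Softmax cross-entropy for one prediction (numerically stabilized)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum - logits[target_index]


def l1_box_loss(pred_box, gt_box):
    """Sum of absolute coordinate differences between two boxes."""
    return sum(abs(p - g) for p, g in zip(pred_box, gt_box))


def matched_losses(pred_logits, pred_boxes, gt_labels, gt_boxes, pairs):
    """Average class and box losses over (prediction, target) pairs."""
    class_loss = sum(cross_entropy(pred_logits[p], gt_labels[t]) for p, t in pairs)
    box_loss = sum(l1_box_loss(pred_boxes[p], gt_boxes[t]) for p, t in pairs)
    n = max(len(pairs), 1)  # avoid division by zero when nothing is matched
    return class_loss / n, box_loss / n


# Illustrative data: two predictions, two ground-truth objects.
pred_logits = [[2.0, 0.0], [0.0, 2.0]]
pred_boxes = [[0.5, 0.5, 0.2, 0.2], [0.1, 0.1, 0.1, 0.1]]
gt_labels = [1, 0]
gt_boxes = [[0.1, 0.1, 0.1, 0.2], [0.5, 0.5, 0.2, 0.2]]
pairs = [(0, 1), (1, 0)]  # assumed output of the matching step

class_loss, box_loss = matched_losses(
    pred_logits, pred_boxes, gt_labels, gt_boxes, pairs
)
print(class_loss, box_loss)
```

Because each loss is evaluated only on matched pairs, the result does not depend on the order in which the model emits its predictions.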