Source code for torchsig.models.spectrogram_models.detr.criterion

"""
Criterion and matching modules from Detectron2, Mask2Former, and DETR codebases
"""
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision
from scipy.optimize import linear_sum_assignment
from torch import Tensor, nn
from torch.cuda.amp import autocast

from .utils import _max_by_axis


def get_world_size() -> int:
    """Return the number of processes in the default distributed group.

    Falls back to 1 when ``torch.distributed`` is unavailable or has not been
    initialized, so single-process code can call this unconditionally.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def get_uncertain_point_coords_with_randomness(
    coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
):
    """
    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty.

    ``int(num_points * oversample_ratio)`` candidate points are drawn uniformly at random.
    The ``int(importance_sample_ratio * num_points)`` candidates that ``uncertainty_func``
    scores highest are kept, and the remaining slots are filled with fresh uniformly random
    points. See the PointRend paper for details.

    Args:
        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
            class-specific or class-agnostic prediction.
        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
            contains logit predictions for P points and returns their uncertainties as a Tensor of
            shape (N, 1, P).
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling parameter (must be >= 1).
        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling
            (must lie in [0, 1]).

    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
            sampled points.
    """
    assert oversample_ratio >= 1
    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0

    batch = coarse_logits.shape[0]
    device = coarse_logits.device
    oversampled = int(num_points * oversample_ratio)
    n_important = int(importance_sample_ratio * num_points)
    n_random = num_points - n_important

    candidates = torch.rand(batch, oversampled, 2, device=device)
    candidate_logits = point_sample(coarse_logits, candidates, align_corners=False)
    # Uncertainty must be computed from the *sampled* logits, not by sampling an
    # uncertainty map of the coarse logits: with uncertainty = -|logit|, a point halfway
    # between coarse logits -1 and 1 has logit 0 (uncertainty 0), whereas sampling the
    # coarse uncertainty map would give it -1.
    candidate_scores = uncertainty_func(candidate_logits)

    # Keep, per mask, the n_important candidates with the highest uncertainty.
    top = torch.topk(candidate_scores[:, 0, :], k=n_important, dim=1).indices
    selected = torch.gather(candidates, 1, top.unsqueeze(-1).expand(-1, -1, 2))

    if n_random > 0:
        extra = torch.rand(batch, n_random, 2, device=device)
        selected = torch.cat([selected, extra], dim=1)
    return selected
def point_sample(input, point_coords, **kwargs):
    """
    Wrap :func:`torch.nn.functional.grid_sample` so it also accepts 3-D point coordinates.

    Unlike ``grid_sample``, the coordinates are expected in the [0, 1] x [0, 1] square; they
    are rescaled to grid_sample's [-1, 1] range internally.

    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
            [0, 1] x [0, 1] normalized point coordinates.
        **kwargs: forwarded unchanged to ``grid_sample`` (e.g. ``align_corners``).

    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for points in `point_coords`, interpolated the same way as ``grid_sample``.
    """
    is_point_list = point_coords.dim() == 3
    # grid_sample needs a 4-D grid: treat (N, P, 2) as an (N, P, 1, 2) grid.
    grid = point_coords.unsqueeze(2) if is_point_list else point_coords
    sampled = F.grid_sample(input, 2.0 * grid - 1.0, **kwargs)
    return sampled.squeeze(3) if is_point_list else sampled
def is_dist_avail_and_initialized():
    """Return True only when ``torch.distributed`` is both available and initialized."""
    return bool(dist.is_available() and dist.is_initialized())
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad a list of 3-D tensors to a common shape and pair them with a padding mask.

    Args:
        tensor_list: tensors with ``ndim == 3``; their shapes may differ.

    Returns:
        NestedTensor whose ``tensors`` is a zero-padded batch of shape
        ``[len(tensor_list)] + per-dimension maximum shape`` and whose ``mask`` is a
        bool tensor of shape (B, H, W) that is False where a tensor's data was copied
        and True on padding.

    Raises:
        ValueError: if the tensors are not 3-dimensional.
    """
    # TODO make this more general
    if tensor_list[0].ndim != 3:
        raise ValueError("not supported")

    if torchvision._is_tracing():
        # The in-place copy below does not export well to ONNX; use the
        # pad-based implementation while tracing.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    batch_shape = [len(tensor_list)] + _max_by_axis([list(t.shape) for t in tensor_list])
    b, c, h, w = batch_shape
    reference = tensor_list[0]
    padded = torch.zeros(batch_shape, dtype=reference.dtype, device=reference.device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=reference.device)
    for i, t in enumerate(tensor_list):
        padded[i, : t.shape[0], : t.shape[1], : t.shape[2]].copy_(t)
        mask[i, : t.shape[1], : t.shape[2]] = False
    return NestedTensor(padded, mask)
class NestedTensor(object):
    """Container pairing a (padded) batched tensor with an optional padding mask."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        # tensors: the batched data; mask: padding mask, or None when there is none.
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        """Return a new NestedTensor with ``tensors`` (and ``mask``, if set) moved to ``device``."""
        moved_tensors = self.tensors.to(device)
        moved_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
# _onnx_nested_tensor_from_tensor_list() is an implementation of # nested_tensor_from_tensor_list() that is supported by ONNX tracing. @torch.jit.unused def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: max_size_list: List[Tensor] = [] for i in range(tensor_list[0].dim()): max_size_i = torch.max( torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) # type: ignore ).to(torch.int64) max_size_list.append(max_size_i) max_size: Tuple[Tensor, ...] = tuple(max_size_list) # work around for # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) # m[: img.shape[1], :img.shape[2]] = False # which is not yet supported in onnx padded_imgs = [] padded_masks = [] for img in tensor_list: padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] padded_img = torch.nn.functional.pad( img, (0, int(padding[2]), 0, int(padding[1]), 0, int(padding[0])), ) padded_imgs.append(padded_img) m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) padded_mask = torch.nn.functional.pad( m, (0, int(padding[2]), 0, int(padding[1])), "constant", 1, ) padded_masks.append(padded_mask.to(torch.bool)) tensor = torch.stack(padded_imgs) mask = torch.stack(padded_masks) return NestedTensor(tensor, mask=mask)
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: A float tensor of arbitrary shape (N, ...). The predictions (logits)
            for each of the N examples.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: normalizer; the summed per-example loss is divided by it.

    Returns:
        Scalar tensor: the sum over examples of the (+1 smoothed) dice loss,
        divided by ``num_masks``.
    """
    inputs = inputs.sigmoid()
    # Flatten everything after the example dim for BOTH tensors so the
    # "arbitrary shape" contract holds (a no-op for (N, P) inputs).
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # +1 on numerator and denominator smooths the ratio and avoids 0/0 for empty masks.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


dice_loss_jit = torch.jit.script(dice_loss)  # type: torch.jit.ScriptModule
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Binary cross-entropy with logits, averaged per example and normalized.

    Args:
        inputs: A float tensor of logits; dim 0 indexes examples and dim 1 is the
            dimension the element-wise loss is averaged over.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: normalizer applied to the summed per-example losses.

    Returns:
        Loss tensor
    """
    elementwise = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    per_example = elementwise.mean(1)
    return per_example.sum() / num_masks


sigmoid_ce_loss_jit = torch.jit.script(sigmoid_ce_loss)  # type: torch.jit.ScriptModule
def calculate_uncertainty(logits):
    """
    Score each location by how close its logit is to the decision boundary.

    The score is the negated L1 distance between the logit and 0.0, so the most
    uncertain locations (logits nearest 0) receive the highest score.

    Args:
        logits (Tensor): A tensor of shape (R, 1, ...) of class-agnostic logits, where R is
            the total number of predicted masks in all images.

    Returns:
        scores (Tensor): A new tensor of shape (R, 1, ...) with the uncertainty scores.
    """
    assert logits.shape[1] == 1
    return -logits.abs()
class SetCriterion(nn.Module):
    """Loss for DETR-style set prediction with class and mask outputs.

    Loss computation happens in two steps:
      1) the matcher computes a Hungarian assignment between ground-truth objects and
         the model's predictions;
      2) every matched (prediction, ground truth) pair is supervised on its class and
         its mask.
    """

    def __init__(
        self,
        num_classes,
        matcher,
        weight_dict,
        eos_coef,
        losses,
        num_points,
        oversample_ratio,
        importance_sample_ratio,
    ):
        """Build the criterion.

        Parameters:
            num_classes: number of object categories, not counting the extra
                "no object" category.
            matcher: callable invoked as ``matcher(outputs, targets)`` that returns the
                per-image (prediction indices, target indices) matching.
            weight_dict: dict mapping loss names to relative weights (stored on the
                criterion; not applied by the methods of this class).
            eos_coef: classification weight given to the "no object" category.
            losses: names of the losses to compute; see ``get_loss``.
            num_points: number of points sampled per mask for the mask losses.
            oversample_ratio: oversampling factor used when choosing those points.
            importance_sample_ratio: fraction of the points chosen by uncertainty
                instead of uniformly at random.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        # One weight per real class plus a trailing "no object" weight of eos_coef.
        class_weights = torch.ones(num_classes + 1)
        class_weights[-1] = eos_coef
        self.register_buffer("empty_weight", class_weights)

        # Point-sampled mask loss parameters.
        self.num_points = num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio

    def loss_labels(self, outputs, targets, indices, num_masks):
        """Weighted cross-entropy classification loss.

        Every query defaults to the "no object" class (index ``num_classes``); matched
        queries take the label of their matched target. Target dicts must contain
        ``"labels"`` (a tensor of dim [nb_target_boxes]). ``num_masks`` is unused.
        """
        assert "pred_logits" in outputs
        logits = outputs["pred_logits"].float()

        batch_idx, query_idx = self._get_src_permutation_idx(indices)
        matched_labels = torch.cat(
            [tgt["labels"][tgt_ids] for tgt, (_, tgt_ids) in zip(targets, indices)]
        )
        labels = torch.full(
            logits.shape[:2], self.num_classes, dtype=torch.int64, device=logits.device
        )
        labels[batch_idx, query_idx] = matched_labels

        # cross_entropy expects the class dim second: (B, num_classes + 1, Q).
        ce = F.cross_entropy(logits.transpose(1, 2), labels, self.empty_weight)
        return {"loss_ce": ce}

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Point-sampled mask losses: sigmoid cross-entropy and dice.

        Target dicts must contain ``"masks"`` (a tensor of dim [nb_target_boxes, h, w]).
        Both losses are evaluated only at points chosen by uncertainty-based sampling.
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        pred_masks = outputs["pred_masks"][src_idx]

        # TODO use valid to mask invalid areas due to padding in loss
        padded, _valid = nested_tensor_from_tensor_list(
            [t["masks"] for t in targets]
        ).decompose()
        gt_masks = padded.to(pred_masks)[tgt_idx]

        # No upsampling needed: point sampling works in normalized coordinates.
        pred_masks = pred_masks[:, None]  # N x 1 x H x W
        gt_masks = gt_masks[:, None]

        with torch.no_grad():
            coords = get_uncertain_point_coords_with_randomness(
                pred_masks,
                calculate_uncertainty,
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            )
            point_labels = point_sample(gt_masks, coords, align_corners=False).squeeze(1)

        point_logits = point_sample(pred_masks, coords, align_corners=False).squeeze(1)

        losses = {
            "loss_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
            "loss_dice": dice_loss_jit(point_logits, point_labels, num_masks),
        }

        del pred_masks
        del gt_masks
        return losses

    def _get_src_permutation_idx(self, indices):
        """Flatten per-image prediction indices into (batch_idx, query_idx) tensors."""
        batch_idx = torch.cat(
            [torch.full_like(pred_ids, b) for b, (pred_ids, _) in enumerate(indices)]
        )
        query_idx = torch.cat([pred_ids for pred_ids, _ in indices])
        return batch_idx, query_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten per-image target indices into (batch_idx, target_idx) tensors."""
        batch_idx = torch.cat(
            [torch.full_like(tgt_ids, b) for b, (_, tgt_ids) in enumerate(indices)]
        )
        target_idx = torch.cat([tgt_ids for _, tgt_ids in indices])
        return batch_idx, target_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks):
        """Dispatch to the loss method registered under the name ``loss``."""
        dispatch = {
            "labels": self.loss_labels,
            "masks": self.loss_masks,
        }
        assert loss in dispatch, f"do you really want to compute {loss} loss?"
        return dispatch[loss](outputs, targets, indices, num_masks)

    def forward(self, outputs, targets):
        """Compute every configured loss.

        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format.
            targets: list of dicts, such that len(targets) == batch_size. The expected keys
                in each dict depend on the losses applied; see each loss' doc.

        Returns:
            dict of loss tensors. Losses of the i-th entry of ``outputs["aux_outputs"]``
            (if present) are stored under their normal key suffixed with ``_{i}``.
        """
        main_outputs = {name: value for name, value in outputs.items() if name != "aux_outputs"}

        # Matching between the final layer's outputs and the targets.
        indices = self.matcher(main_outputs, targets)

        # Average number of target objects across all processes, for normalization.
        total_targets = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [total_targets], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        losses = {}
        for name in self.losses:
            losses.update(self.get_loss(name, outputs, targets, indices, num_masks))

        # Intermediate layers: re-match and compute the same losses with a layer suffix.
        if "aux_outputs" in outputs:
            for layer, aux_outputs in enumerate(outputs["aux_outputs"]):
                aux_indices = self.matcher(aux_outputs, targets)
                for name in self.losses:
                    layer_losses = self.get_loss(
                        name, aux_outputs, targets, aux_indices, num_masks
                    )
                    losses.update({f"{key}_{layer}": value for key, value in layer_losses.items()})

        return losses

    def __repr__(self):
        indent = " " * 4
        rows = [
            "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
            "losses: {}".format(self.losses),
            "weight_dict: {}".format(self.weight_dict),
            "num_classes: {}".format(self.num_classes),
            "eos_coef: {}".format(self.eos_coef),
            "num_points: {}".format(self.num_points),
            "oversample_ratio: {}".format(self.oversample_ratio),
            "importance_sample_ratio: {}".format(self.importance_sample_ratio),
        ]
        header = "Criterion " + self.__class__.__name__
        return "\n".join([header] + [indent + row for row in rows])
def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
    """
    Compute the pairwise DICE loss (similar to generalized IOU) between every prediction
    and every target mask.

    Args:
        inputs: A float tensor of shape (N, ...). Logits for N predicted masks.
        targets: A float tensor of shape (M, ...) whose trailing dims have the same total
            size as those of inputs. Stores the binary classification label for each
            element (0 for the negative class and 1 for the positive class).

    Returns:
        A (N, M) tensor whose entry [n, m] is the (+1 smoothed) dice loss between
        prediction n and target m.
    """
    inputs = inputs.sigmoid()
    # Flatten both operands past the leading dim so the einsum below accepts masks of
    # any shape (a no-op for 2-D inputs).
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss


batch_dice_loss_jit = torch.jit.script(batch_dice_loss)  # type: torch.jit.ScriptModule
def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
    """
    Pairwise mean sigmoid cross-entropy between every prediction and every target mask.

    Args:
        inputs: A float tensor of shape (N, ...). Logits for N predicted masks.
        targets: A float tensor of shape (M, ...) whose trailing dims have the same total
            size as those of inputs. Stores the binary classification label for each
            element (0 for the negative class and 1 for the positive class).

    Returns:
        Loss tensor of shape (N, M): entry [n, m] is the binary cross-entropy of
        prediction n against target m, averaged over mask elements.
    """
    # Flatten both operands past the leading dim so masks of any shape are accepted
    # (a no-op for 2-D inputs).
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    hw = inputs.shape[1]

    # Per-element loss if the label were 1 (pos) or 0 (neg); the einsums then pick,
    # for each target, the right term per element and sum over elements.
    pos = F.binary_cross_entropy_with_logits(inputs, torch.ones_like(inputs), reduction="none")
    neg = F.binary_cross_entropy_with_logits(inputs, torch.zeros_like(inputs), reduction="none")

    loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum("nc,mc->nm", neg, (1 - targets))

    return loss / hw


batch_sigmoid_ce_loss_jit = torch.jit.script(batch_sigmoid_ce_loss)  # type: torch.jit.ScriptModule
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0
    ):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_mask: This is the relative weight of the sigmoid cross-entropy loss of the binary mask
                in the matching cost
            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
            num_points: number of random points, shared by all masks of an image, at which the mask
                costs are evaluated.
                NOTE(review): the default of 0 samples no points, which leaves the point-based mask
                costs degenerate; callers presumably always pass a positive value.

        Raises:
            AssertionError: if all three cost weights are 0.
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice

        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0"

        self.num_points = num_points

    @staticmethod
    def _replace_neg_inf(cost, fill_value: float = -1e9):
        """Return a copy of ``cost`` with every -inf entry replaced by ``fill_value``.

        ``scipy.optimize.linear_sum_assignment`` rejects cost matrices containing -inf,
        so such entries are swapped for a large finite negative number.
        """
        return cost.masked_fill(torch.isneginf(cost), fill_value)

    @torch.no_grad()
    def memory_efficient_forward(self, outputs, targets):
        """More memory-friendly matching: the cost matrix is built one image at a time.

        For each image, the class cost is the negated softmax probability of each target
        class. The mask costs are sigmoid cross-entropy and dice evaluated at
        ``self.num_points`` random points shared by all predicted and target masks.
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        indices = []

        # Iterate through batch size
        for b in range(bs):
            out_prob = outputs["pred_logits"][b].softmax(-1)  # [num_queries, num_classes]
            tgt_ids = targets[b]["labels"]

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob[:, tgt_ids]

            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["masks"].to(out_mask)

            out_mask = out_mask[:, None]
            tgt_mask = tgt_mask[:, None]
            # all masks share the same set of points for efficient matching!
            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device)
            # get gt labels
            tgt_mask = point_sample(
                tgt_mask,
                point_coords.repeat(tgt_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            out_mask = point_sample(
                out_mask,
                point_coords.repeat(out_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            with autocast(enabled=False):
                out_mask = out_mask.float()
                tgt_mask = tgt_mask.float()
                # Sigmoid cross-entropy cost between every prediction/target pair
                cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
                # Dice cost between every prediction/target pair
                with torch.jit.optimized_execution(False):
                    cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()

            # -inf values cause an error in linear_sum_assignment, so replace them with a
            # large finite negative value (a no-op when there are none).
            C = self._replace_neg_inf(C)

            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        return self.memory_efficient_forward(outputs, targets)

    def __repr__(self, _repr_indent=4):
        head = "Matcher " + self.__class__.__name__
        body = [
            "cost_class: {}".format(self.cost_class),
            "cost_mask: {}".format(self.cost_mask),
            "cost_dice: {}".format(self.cost_dice),
            "num_points: {}".format(self.num_points),
        ]
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)