Source code for torchsig.models.spectrogram_models.detr.utils

from typing import List, Optional

import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torchvision.ops.boxes import box_area


def _max_by_axis(the_list):
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


[docs] def drop_classifier(parent): return torch.nn.Sequential(*list(parent.children())[:-2])
[docs] def find_output_features(parent, num_features=0): for n, m in parent.named_children(): if type(m) is torch.nn.Conv2d: num_features = m.out_channels else: num_features = find_output_features(m, num_features) return num_features
[docs] def xcit_name_to_timm_name(input_name: str) -> str: if "nano" in input_name: model_name = "xcit_nano_12_p16_224" elif "tiny" in input_name: if "24" in input_name: model_name = "xcit_tiny_24_p16_224" else: model_name = "xcit_tiny_12_p16_224" elif "small" in input_name: model_name = "xcit_small_24_p8_224" elif "medium" in input_name: model_name = "xcit_medium_24_p8_224" elif "large" in input_name: model_name = "xcit_large_24_p8_224" else: raise NotImplementedError("Input transformer not supported.") return model_name
[docs] def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True
[docs] def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size()
[docs] @torch.no_grad() def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" if target.numel() == 0: return [torch.zeros([], device=output.device)] maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) return res
[docs] def box_cxcywh_to_xyxy(x): x_c, y_c, w, h = x.unbind(-1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
[docs] def box_iou(boxes1, boxes2): area1 = box_area(boxes1) area2 = box_area(boxes2) lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] wh = (rb - lt).clamp(min=0) # [N,M,2] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] union = area1[:, None] + area2 - inter iou = inter / union return iou, union
[docs] def generalized_box_iou(boxes1, boxes2): """ Generalized IoU from https://giou.stanford.edu/ The boxes should be in [x0, y0, x1, y1] format Returns a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) """ # degenerate boxes gives inf / nan results # so do an early check assert (boxes1[:, 2:] >= boxes1[:, :2]).all() assert (boxes2[:, 2:] >= boxes2[:, :2]).all() iou, union = box_iou(boxes1, boxes2) lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) wh = (rb - lt).clamp(min=0) # [N,M,2] area = wh[:, :, 0] * wh[:, :, 1] return iou - (area - union) / area
[docs] def format_preds(preds): map_preds = [] for i, (det_logits, det_boxes) in enumerate(zip(preds["pred_logits"], preds["pred_boxes"])): boxes = [] scores = [] labels = [] # Convert DETR output format to expected bboxes num_objs = 0 pred = {} pred["pred_logits"] = det_logits pred["pred_boxes"] = det_boxes det_list = [] for obj_idx in range(pred["pred_logits"].shape[0]): probs = pred["pred_logits"][obj_idx].softmax(-1) max_prob = probs.max().cpu().detach().numpy() max_class = probs.argmax().cpu().detach().numpy() if max_class != (pred["pred_logits"].shape[1] - 1) and max_prob >= 0.5: center_time = pred["pred_boxes"][obj_idx][0] center_freq = pred["pred_boxes"][obj_idx][1] duration = pred["pred_boxes"][obj_idx][2] bandwidth = pred["pred_boxes"][obj_idx][3] # Save to box, score, label lists x1 = max(0, (center_time - duration / 2) * 512) y1 = max(0, (center_freq - bandwidth / 2) * 512) x2 = min(512, (center_time + duration / 2) * 512) y2 = min(512, (center_freq + bandwidth / 2) * 512) boxes.append([x1, y1, x2, y2]) scores.extend([float(max_prob)]) labels.extend([int(max_class)]) curr_pred = dict( boxes=torch.tensor(boxes).to("cuda"), scores=torch.tensor(scores).to("cuda"), labels=torch.IntTensor(labels).to("cuda"), ) map_preds.append(curr_pred) return map_preds
[docs] def format_targets(labels): map_targets = [] for i, label in enumerate(labels): boxes = [] scores = [] labels = [] for label_obj_idx in range(len(label["labels"])): center_time = label["boxes"][label_obj_idx][0] center_freq = label["boxes"][label_obj_idx][1] duration = label["boxes"][label_obj_idx][2] bandwidth = label["boxes"][label_obj_idx][3] class_idx = label["labels"][label_obj_idx] x1 = (center_time - duration / 2) * 512 y1 = (center_freq - bandwidth / 2) * 512 x2 = (center_time + duration / 2) * 512 y2 = (center_freq + bandwidth / 2) * 512 boxes.append([x1, y1, x2, y2]) labels.extend([int(class_idx)]) curr_target = dict( boxes=torch.tensor(boxes).to("cuda"), labels=torch.IntTensor(labels).to("cuda"), ) map_targets.append(curr_target) return map_targets