# Source code for torchsig.models.spectrogram_models.detr.detr

import os.path
from typing import Dict

import numpy as np
import timm
import torch
from torch import nn

from .modules import *
from .utils import *

# Public API of this module: a wildcard import exposes only the DETR factory.
__all__ = [
    "DETR"
]

# Version strings accepted by DETR(); any other value makes it raise ValueError.
# DETR() looks for the backbone tag ("b0", "b2" or "b4") and the transformer
# tag ("nano") inside the chosen name to pick the matching components.
supported_detr_models = [
    "detr_b0_nano",
    "detr_b2_nano",
    "detr_b4_nano",
]

def DETR(
    version: str,
    num_classes: int = 1,
    drop_rate_backbone: float = 0.2,
    drop_path_rate_backbone: float = 0.2,
    drop_path_rate_transformer: float = 0.1,
):
    """Construct a DETR model with an EfficientNet backbone and an XCiT-Nano transformer.

    The backbone variant (EfficientNet-B0, -B2 or -B4) is selected by the
    ``version`` string; the transformer is always XCiT-Nano.

    DETR from `"End-to-End Object Detection with Transformers" <https://arxiv.org/pdf/2005.12872.pdf>`_.
    EfficientNet from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/pdf/1905.11946.pdf>`_.
    XCiT from `"XCiT: Cross-Covariance Image Transformers" <https://arxiv.org/pdf/2106.09681.pdf>`_.

    Args:
        version (str):
            Which DETR model to build. Must be one of the names listed in
            ``supported_detr_models``.
        num_classes (int):
            Number of output classes, forwarded to ``create_detr``.
        drop_rate_backbone (float):
            Backbone dropout rate for training.
        drop_path_rate_backbone (float):
            Backbone drop path rate for training.
        drop_path_rate_transformer (float):
            Transformer drop path rate for training.

    Returns:
        The model object returned by ``create_detr``.

    Raises:
        ValueError: If ``version`` is not in ``supported_detr_models``, or if
            it contains no recognized backbone / transformer tag.

    Note:
        This function only builds the architecture; it does not load any
        checkpoint. ``num_objects=50``, ``hidden_dim=256``,
        ``ds_rate_transformer=2`` and ``ds_method_transformer="chunker"`` are
        fixed here and cannot be overridden through this factory.
    """
    if version not in supported_detr_models:
        raise ValueError(f"Invalid DETR model: {version}")

    # The backbone is picked by the size tag embedded in the version name.
    if "b0" in version:
        backbone = "efficientnet_b0"
    elif "b2" in version:
        backbone = "efficientnet_b2"
    elif "b4" in version:
        backbone = "efficientnet_b4"
    else:
        # Defensive: only reachable if supported_detr_models gains a name
        # without a matching backbone branch above.
        raise ValueError(f"Invalid DETR model: {version}")

    if "nano" in version:
        transformer = "xcit-nano"
    else:
        # Defensive: same reasoning as for the backbone tag.
        raise ValueError(f"Invalid DETR model: {version}")

    # Create DETR
    mdl = create_detr(
        backbone=backbone,
        transformer=transformer,
        num_classes=num_classes,
        num_objects=50,
        hidden_dim=256,
        drop_rate_backbone=drop_rate_backbone,
        drop_path_rate_backbone=drop_path_rate_backbone,
        drop_path_rate_transformer=drop_path_rate_transformer,
        ds_rate_transformer=2,
        ds_method_transformer="chunker",
    )
    return mdl