(
self,
head,
features=256,
backbone="vitb_rn50_384",
readout="project",
channels_last=False,
use_bn=False,
**kwargs
)
| 30 | |
| 31 | class DPT(BaseModel): |
| 32 | def __init__( |
| 33 | self, |
| 34 | head, |
| 35 | features=256, |
| 36 | backbone="vitb_rn50_384", |
| 37 | readout="project", |
| 38 | channels_last=False, |
| 39 | use_bn=False, |
| 40 | **kwargs |
| 41 | ): |
| 42 | |
| 43 | super(DPT, self).__init__() |
| 44 | |
| 45 | self.channels_last = channels_last |
| 46 | |
| 47 | # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the |
| 48 | # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. |
| 49 | hooks = { |
| 50 | "beitl16_512": [5, 11, 17, 23], |
| 51 | "beitl16_384": [5, 11, 17, 23], |
| 52 | "beitb16_384": [2, 5, 8, 11], |
| 53 | "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] |
| 54 | "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] |
| 55 | "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] |
| 56 | "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] |
| 57 | "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] |
| 58 | "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] |
| 59 | "vitb_rn50_384": [0, 1, 8, 11], |
| 60 | "vitb16_384": [2, 5, 8, 11], |
| 61 | "vitl16_384": [5, 11, 17, 23], |
| 62 | }[backbone] |
| 63 | |
| 64 | if "next_vit" in backbone: |
| 65 | in_features = { |
| 66 | "next_vit_large_6m": [96, 256, 512, 1024], |
| 67 | }[backbone] |
| 68 | else: |
| 69 | in_features = None |
| 70 | |
| 71 | # Instantiate backbone and reassemble blocks |
| 72 | self.pretrained, self.scratch = _make_encoder( |
| 73 | backbone, |
| 74 | features, |
| 75 | False, # Set to true if you want to train from scratch, uses ImageNet weights |
| 76 | groups=1, |
| 77 | expand=False, |
| 78 | exportable=False, |
| 79 | hooks=hooks, |
| 80 | use_readout=readout, |
| 81 | in_features=in_features, |
| 82 | ) |
| 83 | |
| 84 | self.number_layers = len(hooks) if hooks is not None else 4 |
| 85 | size_refinenet3 = None |
| 86 | self.scratch.stem_transpose = None |
| 87 | |
| 88 | if "beit" in backbone: |
| 89 | self.forward_transformer = forward_beit |
no test coverage detected