Method forward

Repository: github.com/apple/ml-depth-pro

src/depth_pro/network/encoder.py, lines 233–332

Encode input at multiple resolutions.

Args: x (torch.Tensor), the input image.
Returns: list[torch.Tensor], the multi-resolution encoded features.

Signature: forward(self, x: torch.Tensor) -> list[torch.Tensor]

```python
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
    """Encode input at multiple resolutions.

    Args:
    ----
        x (torch.Tensor): Input image.

    Returns:
    -------
        Multi-resolution encoded features.

    """
    batch_size = x.shape[0]

    # Step 0: create a 3-level image pyramid.
    x0, x1, x2 = self._create_pyramid(x)

    # Step 1: split into batched, overlapping mini-images at the backbone
    # (BeiT/ViT/Dino) resolution.
    # 5x5 @ 384x384 at the highest resolution (1536x1536).
    x0_patches = self.split(x0, overlap_ratio=0.25)
    # 3x3 @ 384x384 at the middle resolution (768x768).
    x1_patches = self.split(x1, overlap_ratio=0.5)
    # 1x1 @ 384x384 at the lowest resolution (384x384).
    x2_patches = x2

    # Concatenate all the sliding-window patches into one batch of
    # 35 (= 5x5 + 3x3 + 1x1) patches per input image.
    x_pyramid_patches = torch.cat(
        (x0_patches, x1_patches, x2_patches),
        dim=0,
    )

    # Step 2: run the backbone (BeiT) model once on the whole patch batch.
    x_pyramid_encodings = self.patch_encoder(x_pyramid_patches)
    x_pyramid_encodings = self.reshape_feature(
        x_pyramid_encodings, self.out_size, self.out_size
    )

    # Step 3: merging.
    # Merge highres latent encoding.
    x_latent0_encodings = self.reshape_feature(
        self.backbone_highres_hook0,
        self.out_size,
        self.out_size,
    )
    x_latent0_features = self.merge(
        x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
    )

    x_latent1_encodings = self.reshape_feature(
        self.backbone_highres_hook1,
        self.out_size,
        self.out_size,
    )
    x_latent1_features = self.merge(
        x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
    )
```

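Steps 0 and 1 of the listing can be sketched as follows. This is a minimal sketch under stated assumptions: it assumes that _create_pyramid produces 1/2- and 1/4-scale copies with bilinear interpolation, and that split slides a 384-pixel window with stride 384 x (1 - overlap_ratio). The names create_pyramid and split_windows are hypothetical stand-ins, not the repository's helpers. Under these assumptions the window counts in the comments follow directly: a 1536-pixel side at stride 288 gives 5 windows, a 768-pixel side at stride 192 gives 3, and the 384-pixel level is a single window, so each image contributes 5x5 + 3x3 + 1 = 35 patches.

```python
import math

import torch
import torch.nn.functional as F

PATCH = 384  # backbone input size, taken from the listing's comments


def create_pyramid(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Hypothetical stand-in for _create_pyramid: full, 1/2 and 1/4 resolution.
    # The bilinear interpolation mode is an assumption.
    x1 = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
    x2 = F.interpolate(x, scale_factor=0.25, mode="bilinear", align_corners=False)
    return x, x1, x2


def split_windows(x: torch.Tensor, overlap_ratio: float) -> torch.Tensor:
    # Hypothetical stand-in for split: slide a PATCH x PATCH window over a
    # square image with stride PATCH * (1 - overlap_ratio).
    stride = int(PATCH * (1 - overlap_ratio))
    steps = math.ceil((x.shape[-1] - PATCH) / stride) + 1
    windows = [
        x[..., j * stride : j * stride + PATCH, i * stride : i * stride + PATCH]
        for j in range(steps)  # window rows
        for i in range(steps)  # window columns
    ]
    # Window-major order along dim 0: window 0 of every image, then window 1, ...
    return torch.cat(windows, dim=0)


batch_size = 2
x = torch.rand(batch_size, 3, 1536, 1536)
x0, x1, x2 = create_pyramid(x)
x0_patches = split_windows(x0, overlap_ratio=0.25)  # stride 288 -> 5x5 windows
x1_patches = split_windows(x1, overlap_ratio=0.5)   # stride 192 -> 3x3 windows
patches = torch.cat((x0_patches, x1_patches, x2), dim=0)
print(patches.shape)  # torch.Size([70, 3, 384, 384]), i.e. 35 patches per image
```

Because x0_patches comes first in the concatenation, the first batch_size * 5 * 5 entries of the combined batch are the 1536-pixel windows, which is what the listing's [: batch_size * 5 * 5] slices select before merging.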
Callers

nothing calls this directly (as a PyTorch module method, forward is normally invoked through nn.Module.__call__)

Calls (4)

_create_pyramid · method · 0.95
split · method · 0.95
reshape_feature · method · 0.95
merge · method · 0.95

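The last two calls, reshape_feature and merge, turn the backbone's token sequences back into spatial maps and stitch the 5x5 high-resolution windows together. The sketch below shows why padding=3 fits, assuming a 16-pixel ViT patch so that each 384-pixel window becomes a 24x24 token map (out_size = 24) and the 288-pixel window stride becomes 18 tokens. Neighbouring windows then overlap by 6 tokens, and cropping 3 from each side of every shared edge leaves 21 + 3 x 18 + 21 = 96 tokens per side, exactly 1536 / 16. The names reshape_tokens and merge_windows are hypothetical stand-ins, not the repository's code, and the real reshape_feature may also have to discard class or register tokens.

```python
import torch

OUT = 24     # assumed tokens per window side: 384 px / 16 px ViT patches
STRIDE = 18  # assumed token stride between windows: 288 px / 16 px


def reshape_tokens(tokens: torch.Tensor, size: int) -> torch.Tensor:
    # Hypothetical stand-in for reshape_feature: (N, size*size, C) -> (N, C, size, size).
    # Assumes any class or register tokens have already been removed.
    n, _, c = tokens.shape
    return tokens.reshape(n, size, size, c).permute(0, 3, 1, 2)


def merge_windows(x: torch.Tensor, batch_size: int, padding: int) -> torch.Tensor:
    # Hypothetical stand-in for merge: x holds steps*steps windows in
    # window-major order (window idx of image b sits at batch_size * idx + b).
    steps = round((x.shape[0] // batch_size) ** 0.5)
    h, w = x.shape[-2:]
    rows = []
    for j in range(steps):
        row = []
        for i in range(steps):
            idx = j * steps + i
            win = x[batch_size * idx : batch_size * (idx + 1)]
            # Crop `padding` tokens from each edge shared with a neighbouring window.
            top = padding if j > 0 else 0
            bottom = h - padding if j < steps - 1 else h
            left = padding if i > 0 else 0
            right = w - padding if i < steps - 1 else w
            row.append(win[..., top:bottom, left:right])
        rows.append(torch.cat(row, dim=-1))
    return torch.cat(rows, dim=-2)


batch_size, channels, grid = 2, 8, 96  # 96 tokens = 1536 px / 16 px
feature = torch.arange(batch_size * channels * grid * grid, dtype=torch.float32)
feature = feature.reshape(batch_size, channels, grid, grid)

# Cut the grid into 5x5 overlapping windows, window-major as in the split step.
windows = torch.cat(
    [feature[..., j * STRIDE : j * STRIDE + OUT, i * STRIDE : i * STRIDE + OUT]
     for j in range(5) for i in range(5)],
    dim=0,
)
tokens = windows.flatten(2).transpose(1, 2)  # token-sequence layout: (50, 576, 8)
merged = merge_windows(reshape_tokens(tokens, OUT), batch_size=batch_size, padding=3)
print(merged.shape, torch.equal(merged, feature))  # torch.Size([2, 8, 96, 96]) True
```

A round trip on a known grid is a cheap consistency check: with padding=2 the merged map grows to 22 + 3 x 20 + 22 = 104 tokens per side, so a mismatch between padding and window overlap shows up immediately as a wrong shape.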
Tested by

no test coverage detected