MCPcopy
hub / github.com/Pointcept/PointTransformerV3 / __init__

Method __init__

model.py:320–370  ·  view source on GitHub ↗
(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    )

Source from the content-addressed store, hash-verified

318
319class SerializedAttention(PointModule):
def __init__(
    self,
    channels,
    num_heads,
    patch_size,
    qkv_bias=True,
    qk_scale=None,
    attn_drop=0.0,
    proj_drop=0.0,
    order_index=0,
    enable_rpe=False,
    enable_flash=True,
    upcast_attention=True,
    upcast_softmax=True,
):
    """Configure serialized multi-head attention and build its layers.

    Args:
        channels: Input/output feature width. Must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        patch_size: Patch size for attention. With Flash Attention it is
            stored directly as ``self.patch_size``. Without it, it is kept as
            the upper bound ``self.patch_size_max`` and ``self.patch_size``
            starts at 0. It is also passed to ``RPE`` when ``enable_rpe``
            is truthy.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Explicit query/key scaling factor. ``None`` selects the
            default ``(channels // num_heads) ** -0.5``.
        attn_drop: Attention dropout probability. It is kept as a plain
            float on the Flash path and wrapped in ``torch.nn.Dropout``
            otherwise.
        proj_drop: Dropout probability applied by ``self.proj_drop``.
        order_index: Stored as ``self.order_index``. Presumably it selects
            a serialization order in ``forward`` (TODO confirm).
        enable_rpe: Whether to build an ``RPE`` module. Must be falsy when
            ``enable_flash`` is truthy.
        enable_flash: Use Flash Attention. This requires ``flash_attn`` to
            be importable and ``enable_rpe``, ``upcast_attention`` and
            ``upcast_softmax`` to all be falsy.
        upcast_attention: Stored flag. Must be falsy with Flash Attention.
        upcast_softmax: Stored flag. Must be falsy with Flash Attention.

    Raises:
        AssertionError: If ``channels`` is not divisible by ``num_heads``,
            or if Flash Attention is enabled together with an incompatible
            option or without ``flash_attn`` installed. Note that the
            defaults (``enable_flash=True`` with ``upcast_attention=True``
            and ``upcast_softmax=True``) always fail these checks, so callers
            must override them explicitly.
            NOTE(review): these are ``assert`` statements and are skipped
            under ``python -O``.
    """
    super().__init__()
    assert channels % num_heads == 0
    self.channels = channels
    self.num_heads = num_heads
    head_dim = channels // num_heads
    # Explicit None check: a caller-provided scale must never be replaced by
    # the default merely because it happens to be falsy.
    self.scale = qk_scale if qk_scale is not None else head_dim**-0.5
    self.order_index = order_index
    self.upcast_attention = upcast_attention
    self.upcast_softmax = upcast_softmax
    self.enable_rpe = enable_rpe
    self.enable_flash = enable_flash
    if enable_flash:
        # Truthiness checks, consistent with how ``self.enable_rpe`` is
        # tested when deciding whether to build ``RPE`` below.
        assert not enable_rpe, "Set enable_rpe to False when enable Flash Attention"
        assert (
            not upcast_attention
        ), "Set upcast_attention to False when enable Flash Attention"
        assert (
            not upcast_softmax
        ), "Set upcast_softmax to False when enable Flash Attention"
        assert flash_attn is not None, "Make sure flash_attn is installed."
        self.patch_size = patch_size
        # Kept as a float probability; the Flash path does not wrap it in a
        # Dropout module.
        self.attn_drop = attn_drop
    else:
        # when disable flash attention, we still don't want to use mask
        # consequently, patch size will auto set to the
        # min number of patch_size_max and number of points
        self.patch_size_max = patch_size
        self.patch_size = 0
        self.attn_drop = torch.nn.Dropout(attn_drop)

    # Fused projection producing queries, keys and values in one matmul.
    self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
    self.proj = torch.nn.Linear(channels, channels)
    self.proj_drop = torch.nn.Dropout(proj_drop)
    self.softmax = torch.nn.Softmax(dim=-1)
    self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None
371
372 @torch.no_grad()
373 def get_rel_pos(self, point, order):

Callers

nothing calls this directly

Calls 2

RPE (class) · 0.85
__init__ (method) · 0.45

Tested by

no test coverage detected