Repository: github.com/shenweichen/DeepCTR-Torch

Class FM

deepctr_torch/layers/interaction.py:12–34

Factorization Machine models pairwise (order-2) feature interactions without the linear term and bias. Input shape: a 3D tensor with shape ``(batch_size, field_size, embedding_size)``. Output shape: a 2D tensor with shape ``(batch_size, 1)``. References: [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf).
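The `square_of_sum` / `sum_of_square` computation in `forward` relies on the usual FM reformulation of the pairwise term. A sketch of the derivation, assuming (our notation, not the source's) that `v_i` is the embedding row of field `i`, `v_{i,f}` is its `f`-th component, `n` is `field_size`, and `k` is `embedding_size`:

```latex
\left(\sum_{i=1}^{n} v_{i,f}\right)^2 = \sum_{i=1}^{n} v_{i,f}^2 + 2\sum_{i=1}^{n}\sum_{j=i+1}^{n} v_{i,f}\,v_{j,f}
\\[4pt]
\Longrightarrow\quad
\sum_{i=1}^{n}\sum_{j=i+1}^{n}\langle \mathbf{v}_i,\mathbf{v}_j\rangle
= \frac{1}{2}\sum_{f=1}^{k}\left[\left(\sum_{i=1}^{n} v_{i,f}\right)^2 - \sum_{i=1}^{n} v_{i,f}^2\right]
```

Subtracting the sum of squares removes the `i = j` self-interactions, and the factor 1/2 removes the double count of each unordered pair, so the O(n²k) pairwise sum becomes O(nk) work. In the code, the inner sums over `i` are `torch.sum(..., dim=1)` and the outer sum over `f` is `torch.sum(..., dim=2)`.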


import torch
import torch.nn as nn


class FM(nn.Module):
    """Factorization Machine models pairwise (order-2) feature interactions
    without the linear term and bias.
      Input shape
        - 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
      Output shape
        - 2D tensor with shape: ``(batch_size, 1)``.
      References
        - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
    """

    def __init__(self):
        super(FM, self).__init__()

    def forward(self, inputs):
        fm_input = inputs

        # Square of the per-dimension sum over fields: (batch_size, 1, embedding_size)
        square_of_sum = torch.pow(torch.sum(fm_input, dim=1, keepdim=True), 2)
        # Per-dimension sum of squares over fields: (batch_size, 1, embedding_size)
        sum_of_square = torch.sum(fm_input * fm_input, dim=1, keepdim=True)
        cross_term = square_of_sum - sum_of_square
        # Halve and sum over the embedding axis: (batch_size, 1)
        cross_term = 0.5 * torch.sum(cross_term, dim=2, keepdim=False)

        return cross_term
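To check that `forward` really equals the sum of pairwise dot products, the sketch below mirrors its arithmetic for a single sample in plain Python (no PyTorch needed) and compares it with a brute-force sum over all field pairs. The helper names `fm_cross_term` and `pairwise_brute_force` are hypothetical and exist only for this illustration.

```python
import random


def fm_cross_term(fields):
    """Plain-Python mirror of FM.forward for one sample.

    `fields` is a list of field_size embedding vectors, each of length
    embedding_size. Returns the scalar order-2 interaction term.
    """
    k = len(fields[0])
    total = 0.0
    for f in range(k):
        column = [v[f] for v in fields]               # component f of every field
        square_of_sum = sum(column) ** 2              # (sum_i v_if)^2
        sum_of_square = sum(x * x for x in column)    # sum_i v_if^2
        total += square_of_sum - sum_of_square
    return 0.5 * total                                # halve, as in forward


def pairwise_brute_force(fields):
    """Direct O(n^2 k) sum of <v_i, v_j> over all pairs i < j."""
    n = len(fields)
    return sum(
        sum(a * b for a, b in zip(fields[i], fields[j]))
        for i in range(n)
        for j in range(i + 1, n)
    )


random.seed(0)
sample = [[random.uniform(-1.0, 1.0) for _ in range(4)] for _ in range(5)]
print(fm_cross_term(sample), pairwise_brute_force(sample))
```

The two printed values should agree up to floating-point rounding. With PyTorch installed, `FM()(torch.randn(32, 10, 8))` runs the same computation batched and returns a tensor of shape `(32, 1)`.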

Callers 4

__init__ · Method · 0.85
__init__ · Method · 0.85
__init__ · Method · 0.85
__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected