Factorization Machine layer that models pairwise (order-2) feature interactions, without the linear term and bias. Input shape: 3D tensor with shape ``(batch_size, field_size, embedding_size)``. Output shape: 2D tensor with shape ``(batch_size, 1)``. References - [Fa
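As a rough illustration of what "pairwise (order-2) interactions without linear term and bias" means for a single sample, the sketch below compares a brute-force sum of dot products over all pairs of field embeddings with the square-of-sum minus sum-of-square shortcut. It is only a sketch: the function names and sample values are hypothetical, and plain Python lists stand in for tensors so the example runs without any dependencies.

```python
from itertools import combinations


def fm_pairwise_brute_force(embeddings):
    """Sum of dot products over all unordered pairs of field embeddings."""
    total = 0.0
    for v_i, v_j in combinations(embeddings, 2):
        total += sum(a * b for a, b in zip(v_i, v_j))
    return total


def fm_pairwise_fast(embeddings):
    """Same quantity via 0.5 * sum_k((sum_i v_ik)^2 - sum_i v_ik^2)."""
    total = 0.0
    for column in zip(*embeddings):  # one column per embedding dimension
        square_of_sum = sum(column) ** 2
        sum_of_square = sum(x * x for x in column)
        total += square_of_sum - sum_of_square
    return 0.5 * total


# One hypothetical sample: 3 fields, embedding size 2.
sample = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
brute = fm_pairwise_brute_force(sample)
fast = fm_pairwise_fast(sample)
print(brute, fast)
```

The brute-force version visits every pair of fields, so its cost grows quadratically with the number of fields, while the shortcut makes a single pass over the fields for each embedding dimension.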


class FM(nn.Module):
    """Factorization Machine layer that models pairwise (order-2) feature
    interactions, without the linear term and bias.

    Input shape
      - 3D tensor with shape: ``(batch_size, field_size, embedding_size)``.
    Output shape
      - 2D tensor with shape: ``(batch_size, 1)``.
    References
      - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        fm_input = inputs

        # Square of the sum over fields: shape (batch_size, 1, embedding_size).
        square_of_sum = torch.pow(torch.sum(fm_input, dim=1, keepdim=True), 2)
        # Sum over fields of the squares: same shape as above.
        sum_of_square = torch.sum(fm_input * fm_input, dim=1, keepdim=True)
        cross_term = square_of_sum - sum_of_square
        # Reduce over the embedding dimension: shape (batch_size, 1).
        cross_term = 0.5 * torch.sum(cross_term, dim=2, keepdim=False)

        return cross_term


class BiInteractionPooling(nn.Module):