| 561 | return conv_weights, bn_weights, relu_weights |
| 562 | |
def forward(self, x):
    """Run the four backbone stages and return sigmoid side/fused outputs.

    Args:
        x: Input tensor. ``x.size()[2:]`` is unpacked as ``(H, W)``, so it is
            presumably 4-D ``(N, C, H, W)`` -- TODO confirm against callers.

    Returns:
        A list of five tensors, each passed through ``torch.sigmoid``:
        ``[e1, e2, e3, e4, fused]``. ``e1``..``e4`` are the per-stage
        features after ``self.conv_reduces[i]``, bilinearly resized to
        ``(H, W)``. ``fused`` is ``self.classifier`` applied to
        ``e1``..``e4`` concatenated along dim 1.
    """
    H, W = x.size()[2:]

    x = self.init_block(x)

    # Stage 1
    x1 = self.block1_1(x)
    x1 = self.block1_2(x1)
    x1 = self.block1_3(x1)

    # Stage 2
    x2 = self.block2_1(x1)
    x2 = self.block2_2(x2)
    x2 = self.block2_3(x2)
    x2 = self.block2_4(x2)

    # Stage 3
    x3 = self.block3_1(x2)
    x3 = self.block3_2(x3)
    x3 = self.block3_3(x3)
    x3 = self.block3_4(x3)

    # Stage 4
    x4 = self.block4_1(x3)
    x4 = self.block4_2(x4)
    x4 = self.block4_3(x4)
    x4 = self.block4_4(x4)

    # Optional per-stage refinement. The dilation module (when self.dil is
    # set) runs before the attention module (when self.sa is truthy).
    # With neither enabled, the stage outputs are used unchanged.
    x_fuses = []
    for i, feat in enumerate((x1, x2, x3, x4)):
        if self.dil is not None:
            feat = self.dilations[i](feat)
        if self.sa:
            feat = self.attentions[i](feat)
        x_fuses.append(feat)

    # Reduce each stage with its own conv_reduces[i] module and resize it
    # back to the input's spatial size so the maps can be concatenated.
    outputs = [
        F.interpolate(
            self.conv_reduces[i](feat),
            (H, W),
            mode="bilinear",
            align_corners=False,
        )
        for i, feat in enumerate(x_fuses)
    ]

    output = self.classifier(torch.cat(outputs, dim=1))
    outputs.append(output)
    return [torch.sigmoid(r) for r in outputs]