(self)
| 539 | self.assertGreater(losses['task1_loss'].item(), 0) |
| 540 | |
def test_predict(self):
    """Check the multi-task head's ``predict`` in three configurations.

    1. No ``data_samples``: a sequence of ``MultiTaskDataSample`` is
       returned. Each one holds a ``task0`` entry whose ``pred_score`` is a
       ``torch.Tensor``.
    2. Flat ``data_samples`` (built for every head in ``DEFAULT_ARGS``):
       the returned samples are the very same objects that were passed in.
    3. Nested heads (``DEFAULT_ARGS2``) with a ``None`` first entry: every
       returned element is a ``MultiTaskDataSample``, and the real samples
       keep their nested ``task0 -> task01`` structure.
    """
    feats = (torch.rand(4, 10), )

    # Four flat samples, each with gt_label 1 for every configured task.
    data_samples = []
    for _ in range(4):
        data_sample = MultiTaskDataSample()
        for task_name in self.DEFAULT_ARGS['task_heads']:
            task_sample = DataSample().set_gt_label(1)
            data_sample.set_field(task_sample, task_name)
        data_samples.append(data_sample)
    head = MODELS.build(self.DEFAULT_ARGS)

    # without data_samples
    predictions = head.predict(feats)
    self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
    for pred in predictions:
        self.assertIn('task0', pred)
        # assertIsInstance, not assertTrue(type(...)): a type object is
        # always truthy, so assertTrue on it can never fail.
        self.assertIsInstance(pred.task0.pred_score, torch.Tensor)

    # with data_samples: predictions are filled into the given samples,
    # so the identical objects must come back.
    predictions = head.predict(feats, data_samples)
    self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
    # Guard the zip below against silently skipping missing predictions.
    self.assertEqual(len(predictions), len(data_samples))
    for sample, pred in zip(data_samples, predictions):
        self.assertIs(sample, pred)
        self.assertIn('task0', pred)

    # with data samples and nested
    head_nested = MODELS.build(self.DEFAULT_ARGS2)
    # adding a None data sample at the beginning
    data_samples_nested = [None]
    for _ in range(3):
        data_sample_nested = MultiTaskDataSample()
        data_sample_nested0 = MultiTaskDataSample()
        data_sample_nested0.set_field(DataSample().set_gt_label(1),
                                      'task00')
        data_sample_nested0.set_field(DataSample().set_gt_label(1),
                                      'task01')
        data_sample_nested.set_field(data_sample_nested0, 'task0')
        data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1')
        data_samples_nested.append(data_sample_nested)

    predictions = head_nested.predict(feats, data_samples_nested)
    # Also covers slot 0: the None input must come back as a real sample.
    self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
    self.assertEqual(len(predictions), len(data_samples_nested))
    # Skip slot 0 (None on input); only the real samples carry nested gt.
    for sample, pred in zip(data_samples_nested[1:], predictions[1:]):
        self.assertIn('task0', pred)
        self.assertIn('task1', pred)
        self.assertIn('task01', pred.get('task0'))
        self.assertEqual(
            sample.get('task0').get('task01').gt_label.numpy()[0], 1)
| 593 | def test_loss_empty_data_sample(self): |
| 594 | feats = (torch.rand(4, 10), ) |
nothing calls this directly
no test coverage detected