MCPcopy
hub / github.com/open-mmlab/mmpretrain / test_predict

Method test_predict

tests/test_models/test_heads.py:541–591  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

539 self.assertGreater(losses['task1_loss'].item(), 0)
540
def test_predict(self):
    """Test ``predict`` of the multi-task heads built from the test configs.

    Three scenarios are covered:

    1. No ``data_samples``: one ``MultiTaskDataSample`` is returned per
       batch item, each with a ``task0`` field whose ``pred_score`` is a
       ``torch.Tensor``.
    2. With ``data_samples``: the given samples are returned (same objects),
       each containing a ``task0`` field.
    3. Nested heads (``DEFAULT_ARGS2``) with a ``None`` placeholder as the
       first data sample: every real sample gains ``task0``/``task1``
       fields, the nested ``task0`` contains ``task01``, and the
       ground-truth label set before prediction is preserved.
    """
    batch_size = 4
    feats = (torch.rand(batch_size, 10), )

    # One MultiTaskDataSample per batch item, carrying a ground-truth
    # sub-sample for every task head configured in DEFAULT_ARGS.
    data_samples = []
    for _ in range(batch_size):
        data_sample = MultiTaskDataSample()
        for task_name in self.DEFAULT_ARGS['task_heads']:
            task_sample = DataSample().set_gt_label(1)
            data_sample.set_field(task_sample, task_name)
        data_samples.append(data_sample)
    head = MODELS.build(self.DEFAULT_ARGS)

    # without data_samples
    predictions = head.predict(feats)
    self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
    # Guard against an empty result silently passing the checks below.
    self.assertEqual(len(predictions), batch_size)
    for pred in predictions:
        self.assertIn('task0', pred)
    task0_sample = predictions[0].task0
    # assertIsInstance is needed here: assertTrue(type(x), ...) always
    # passes because a type object is truthy (the 2nd arg is just the msg).
    self.assertIsInstance(task0_sample.pred_score, torch.Tensor)

    # with data_samples
    predictions = head.predict(feats, data_samples)
    self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
    self.assertEqual(len(predictions), len(data_samples))
    for sample, pred in zip(data_samples, predictions):
        # The head must fill and return the caller's samples, not copies.
        self.assertIs(sample, pred)
        self.assertIn('task0', pred)

    # with data samples and nested
    head_nested = MODELS.build(self.DEFAULT_ARGS2)
    # adding a None data sample at the beginning
    data_samples_nested = [None]
    for _ in range(batch_size - 1):
        data_sample_nested = MultiTaskDataSample()
        data_sample_nested0 = MultiTaskDataSample()
        data_sample_nested0.set_field(DataSample().set_gt_label(1),
                                      'task00')
        data_sample_nested0.set_field(DataSample().set_gt_label(1),
                                      'task01')
        data_sample_nested.set_field(data_sample_nested0, 'task0')
        data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1')
        data_samples_nested.append(data_sample_nested)

    predictions = head_nested.predict(feats, data_samples_nested)
    self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
    self.assertEqual(len(predictions), len(data_samples_nested))
    # Skip the leading None placeholder; only the real samples carry
    # ground-truth labels to compare against.
    for sample, pred in zip(data_samples_nested[1:], predictions[1:]):
        self.assertIn('task0', pred)
        self.assertIn('task1', pred)
        self.assertIn('task01', pred.get('task0'))
        self.assertEqual(
            sample.get('task0').get('task01').gt_label.numpy()[0], 1)
592
593 def test_loss_empty_data_sample(self):
594 feats = (torch.rand(4, 10), )

Callers

nothing calls this directly

Calls 5

MultiTaskDataSample (class) · 0.90
DataSample (class) · 0.90
set_gt_label (method) · 0.80
get (method) · 0.80
predict (method) · 0.45

Tested by

no test coverage detected