Repository: github.com/open-mmlab/mmpose

Function get_packed_inputs

mmpose/testing/_utils.py:81–190

Create a dummy batch of model inputs and data samples.

get_packed_inputs(batch_size=2,
                  num_instances=1,
                  num_keypoints=17,
                  num_levels=1,
                  img_shape=(256, 192),
                  input_size=(192, 256),
                  heatmap_size=(48, 64),
                  simcc_split_ratio=2.0,
                  with_heatmap=True,
                  with_reg_label=True,
                  with_simcc_label=True)
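
The source listing on this page unpacks `img_shape` as `h, w` and draws the dummy image from a `RandomState` seeded with 0. The following is a minimal NumPy-only sketch of that one step, not the library code: the helper name `make_dummy_image` and its `seed` parameter are hypothetical, and the `torch.from_numpy` conversion from the listing is left out so the sketch runs without PyTorch.

```python
import numpy as np


def make_dummy_image(img_shape=(256, 192), seed=0):
    """Hypothetical helper mirroring the image-creation step of the listing."""
    h, w = img_shape  # img_shape is interpreted as (height, width)
    rng = np.random.RandomState(seed)
    # The upper bound of randint is exclusive, so pixel values lie in [0, 254].
    return rng.randint(0, 255, size=(3, h, w), dtype=np.uint8)


img_a = make_dummy_image()
img_b = make_dummy_image()

print(img_a.shape, img_a.dtype)      # channel-first layout: (3, h, w)
print(np.array_equal(img_a, img_b))  # same seed, same pixels
```

Because the generator is seeded, two calls produce identical images. The listing shuffles `flip_indices` with the module-level `np.random.shuffle` rather than `rng`, so that field is not pinned by the seed.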

Source (excerpt, lines 81–138)

def get_packed_inputs(batch_size=2,
                      num_instances=1,
                      num_keypoints=17,
                      num_levels=1,
                      img_shape=(256, 192),
                      input_size=(192, 256),
                      heatmap_size=(48, 64),
                      simcc_split_ratio=2.0,
                      with_heatmap=True,
                      with_reg_label=True,
                      with_simcc_label=True):
    """Create a dummy batch of model inputs and data samples."""
    rng = np.random.RandomState(0)

    inputs_list = []
    for idx in range(batch_size):
        inputs = dict()

        # input
        h, w = img_shape
        image = rng.randint(0, 255, size=(3, h, w), dtype=np.uint8)
        inputs['inputs'] = torch.from_numpy(image)

        # attributes
        bboxes = _rand_bboxes(rng, num_instances, w, h)
        bbox_centers, bbox_scales = bbox_xyxy2cs(bboxes)

        keypoints = _rand_keypoints(rng, bboxes, num_keypoints)
        keypoints_visible = np.ones((num_instances, num_keypoints),
                                    dtype=np.float32)

        # meta
        img_meta = {
            'id': idx,
            'img_id': idx,
            'img_path': '<demo>.png',
            'img_shape': img_shape,
            'input_size': input_size,
            'input_center': bbox_centers,
            'input_scale': bbox_scales,
            'flip': False,
            'flip_direction': None,
            'flip_indices': list(range(num_keypoints))
        }

        np.random.shuffle(img_meta['flip_indices'])
        data_sample = PoseDataSample(metainfo=img_meta)

        # gt_instance
        gt_instances = InstanceData()
        gt_instance_labels = InstanceData()

        # [N, K] -> [N, num_levels, K]
        # keep the first dimension as the num_instances
        if num_levels > 1:
            keypoint_weights = np.tile(keypoints_visible[:, None],
                                       (1, num_levels, 1))
        else:
            # ... (remainder of the function is not included in this excerpt)
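
The `num_levels > 1` branch at the end of the excerpt repeats the per-instance visibility flags once per level, turning an `[N, K]` array into `[N, num_levels, K]`. The sketch below isolates that reshaping with plain NumPy. The helper name `expand_to_levels` is hypothetical, and only the multi-level branch is covered, because the `else` branch is cut off in the excerpt.

```python
import numpy as np


def expand_to_levels(keypoints_visible, num_levels):
    """Hypothetical helper: repeat [N, K] visibility flags for every level."""
    # Indexing with [:, None] inserts a level axis: [N, K] -> [N, 1, K].
    # np.tile then repeats that axis: [N, 1, K] -> [N, num_levels, K].
    return np.tile(keypoints_visible[:, None], (1, num_levels, 1))


# Two instances with 17 keypoints each, all marked visible.
visible = np.ones((2, 17), dtype=np.float32)
weights = expand_to_levels(visible, num_levels=3)

print(visible.shape)  # (2, 17)
print(weights.shape)  # (2, 3, 17)
```

The first axis stays the instance axis, matching the comment in the source, and `np.tile` preserves the `float32` dtype of the input.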

Callers (15)

test_predict (Method) · 0.90
test_tta (Method) · 0.90
test_loss (Method) · 0.90
_get_data_samples (Method) · 0.90
test_predict (Method) · 0.90
test_tta (Method) · 0.90
test_loss (Method) · 0.90
test_predict (Method) · 0.90
test_tta (Method) · 0.90
test_loss (Method) · 0.90
test_predict (Method) · 0.90
test_tta (Method) · 0.90

Calls (6)

bbox_xyxy2cs (Function) · 0.90
PoseDataSample (Class) · 0.90
MultilevelPixelData (Class) · 0.90
_rand_bboxes (Function) · 0.85
_rand_keypoints (Function) · 0.85
_rand_simcc_label (Function) · 0.85

Tested by (15)

test_predict (Method) · 0.72
test_tta (Method) · 0.72
test_loss (Method) · 0.72
_get_data_samples (Method) · 0.72
test_predict (Method) · 0.72
test_tta (Method) · 0.72
test_loss (Method) · 0.72
test_predict (Method) · 0.72
test_tta (Method) · 0.72
test_loss (Method) · 0.72
test_predict (Method) · 0.72
test_tta (Method) · 0.72