Function
build_decoder_sample
(sample, dec_ids, dec_position, dec_masks, dec_target, dec_logit_mask)
Source from the content-addressed store, hash-verified
```python
import numpy as np


def build_decoder_sample(sample, dec_ids, dec_position, dec_masks, dec_target, dec_logit_mask):
    # Store the decoder inputs on `sample` as NumPy arrays. Note the key renames:
    # dec_ids -> 'dec_text' and dec_masks -> 'dec_mask'. The dict is updated
    # in place and also returned.
    sample['dec_text'] = np.array(dec_ids)
    sample['dec_position'] = np.array(dec_position)
    sample['dec_mask'] = np.array(dec_masks)
    sample['dec_target'] = np.array(dec_target)
    sample['dec_logit_mask'] = np.array(dec_logit_mask)
    return sample
```
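In the source file this function is followed by `my_collate(batch)`, whose body is not shown here. As a hedged sketch of how samples produced by `build_decoder_sample` might be batched, the helper below stacks each decoder field along a new leading batch axis. `stack_decoder_fields`, `DEC_KEYS`, and the sample values are hypothetical illustrations, not the module's `my_collate`, and the sketch assumes every sample was already padded so that arrays under the same key share one shape (`np.stack` raises otherwise).

```python
import numpy as np

# Keys written by build_decoder_sample.
DEC_KEYS = ('dec_text', 'dec_position', 'dec_mask', 'dec_target', 'dec_logit_mask')


def stack_decoder_fields(samples):
    """Hypothetical helper (not the module's my_collate).

    Stacks each decoder field across samples into one array with a leading
    batch dimension. Assumes the per-key arrays were padded to equal shapes.
    """
    return {key: np.stack([s[key] for s in samples]) for key in DEC_KEYS}


# Two hypothetical, already-padded samples of length 3.
samples = [
    {'dec_text': np.array([5, 6, 7]), 'dec_position': np.array([0, 1, 2]),
     'dec_mask': np.array([1, 1, 1]), 'dec_target': np.array([6, 7, 2]),
     'dec_logit_mask': np.array([1, 1, 1])},
    {'dec_text': np.array([8, 9, 0]), 'dec_position': np.array([0, 1, 2]),
     'dec_mask': np.array([1, 1, 0]), 'dec_target': np.array([9, 2, 0]),
     'dec_logit_mask': np.array([1, 1, 0])},
]
batch = stack_decoder_fields(samples)
print({k: v.shape for k, v in batch.items()})  # each value has shape (2, 3)
```

Stacking per key keeps the dict layout the per-sample function already uses; variable-length sequences would need padding (or a different batching strategy) before this step.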
Tested by
no test coverage detected
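Since no coverage was detected, a minimal pytest-style test sketch could pin down the function's observable behavior: the in-place update, the preserved existing keys, and the renamed keys (`dec_ids` stored as `'dec_text'`, `dec_masks` as `'dec_mask'`). The function body is copied from the listing so the block runs standalone; the test name, the pre-existing `'text'` key, and the input values are hypothetical.

```python
import numpy as np


def build_decoder_sample(sample, dec_ids, dec_position, dec_masks, dec_target, dec_logit_mask):
    # Copied from the listing above so this sketch runs on its own.
    sample['dec_text'] = np.array(dec_ids)
    sample['dec_position'] = np.array(dec_position)
    sample['dec_mask'] = np.array(dec_masks)
    sample['dec_target'] = np.array(dec_target)
    sample['dec_logit_mask'] = np.array(dec_logit_mask)
    return sample


def test_build_decoder_sample_maps_fields():
    # A hypothetical pre-existing field that must survive the update.
    sample = {'text': np.array([1, 2])}
    out = build_decoder_sample(sample, [3, 4], [0, 1], [1, 1], [4, 0], [1, 0])
    assert out is sample                          # mutated in place and returned
    assert out['text'].tolist() == [1, 2]         # existing keys are preserved
    assert out['dec_text'].tolist() == [3, 4]     # dec_ids -> 'dec_text'
    assert out['dec_mask'].tolist() == [1, 1]     # dec_masks -> 'dec_mask'
    assert out['dec_logit_mask'].tolist() == [1, 0]
    assert all(isinstance(out[k], np.ndarray)
               for k in ('dec_text', 'dec_position', 'dec_mask',
                         'dec_target', 'dec_logit_mask'))
```

Running `pytest` on a file containing this block collects and runs the test function.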