(
x: Tensor, dataFormat: DataFormat)
| 38 | * @param dataFormat |
| 39 | */ |
export function preprocessConv2DInput(
    x: Tensor, dataFormat: DataFormat): Tensor {
  // TODO(cais): Cast type to float32 if not.
  return tidy(() => {
    // Validate the requested layout before touching the tensor.
    // NOTE(review): presumably checkDataFormat throws on an unsupported
    // value — its behavior is not visible here; confirm.
    checkDataFormat(dataFormat);

    // Channels-first input is permuted NCHW -> NHWC; any other format is
    // handed back as-is.
    const needsTranspose = dataFormat === 'channelsFirst';
    return needsTranspose ? tfc.transpose(x, [0, 2, 3, 1]) : x;
  });
}
| 52 | |
| 53 | /** |
| 54 | * Transpose and cast the input before the conv3d. |
no test coverage detected
searching dependent graphs…