(args: {
inputs: TransformInputs,
backend: WebGPUBackend,
attrs: TransformAttrs
})
| 21 | import {TransformProgram} from '../transform_webgpu'; |
| 22 | |
/**
 * WebGPU kernel implementation of the `Transform` op.
 *
 * Reads `image.shape` as `[batch, height, width, channels]`. It builds an
 * output of shape `[batch, outHeight, outWidth, channels]`. The spatial size
 * comes from `attrs.outputShape` when provided; otherwise it is the input
 * image's height and width.
 *
 * The interpolation and fill modes are encoded as integer ids and passed to a
 * `TransformProgram`, together with `fillValue`, as uniforms:
 *   - interpolation: 'nearest' -> 1, anything else -> 2
 *   - fillMode: 'constant' -> 1, 'reflect' -> 2, 'wrap' -> 3,
 *     'nearest' -> 4, anything else -> 1
 *
 * @param args.inputs `image` and `transforms`, both handed to the program.
 * @param args.backend The WebGPU backend that runs the program.
 * @param args.attrs `interpolation`, `fillMode`, `fillValue`, and an
 *     optional `outputShape`.
 * @returns The float32 tensor produced by `backend.runWebGPUProgram`.
 */
export function transform(args: {
  inputs: TransformInputs,
  backend: WebGPUBackend,
  attrs: TransformAttrs
}): TensorInfo {
  const {inputs, backend, attrs} = args;
  const {image, transforms} = inputs;
  const {interpolation, fillMode, fillValue, outputShape} = attrs;

  const [batch, inHeight, inWidth, channels] = image.shape;

  // Without an explicit output shape, keep the input's spatial dimensions.
  let outHeight: number;
  let outWidth: number;
  if (outputShape == null) {
    outHeight = inHeight;
    outWidth = inWidth;
  } else {
    [outHeight, outWidth] = outputShape;
  }
  const outShape: [number, number, number, number] =
      [batch, outHeight, outWidth, channels];

  const program = new TransformProgram(outShape);

  // Integer mode ids passed to the program as int32 uniforms.
  const interpolationModeId = interpolation !== 'nearest' ? 2 : 1;
  const fillModeIds = new Map<string, number>([
    ['constant', 1],
    ['reflect', 2],
    ['wrap', 3],
    ['nearest', 4],
  ]);
  // Unrecognized fill modes fall back to the 'constant' id.
  const fillModeId = fillModeIds.get(fillMode) ?? 1;

  const uniformData = [
    {type: 'int32', data: [interpolationModeId]},
    {type: 'int32', data: [fillModeId]},
    {type: 'float32', data: [fillValue]},
  ];

  return backend.runWebGPUProgram(
      program, [image, transforms], 'float32', uniformData);
}
| 66 | |
| 67 | export const transformConfig: KernelConfig = { |
| 68 | kernelName: Transform, |
nothing calls this directly
no test coverage detected
searching dependent graphs…