| 19 | import {computeDispatch, flatDispatchLayout} from './webgpu_util'; |
| 20 | |
/**
 * WebGPU program that builds a one-hot tensor of shape `[numIndices, depth]`.
 *
 * Output element `(i, j)` is `onValue` when the value read from input `x` at
 * position `i` (rounded to the nearest integer) equals `j`, and `offValue`
 * otherwise. `onValue` / `offValue` are passed in as `f32` uniforms.
 */
export class OneHotProgram implements WebGPUProgram {
  outputShape: number[];
  shaderKey: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  variableNames = ['x'];
  uniforms = 'onValue : f32, offValue : f32,';
  workgroupSize: [number, number, number] = [64, 1, 1];
  // NOTE(review): presumably asks the framework to provide `uniforms.size`
  // (the flat output size), which the shader uses as its bounds check.
  size = true;

  /**
   * @param numIndices Number of entries in `x`; becomes the output's first
   *     dimension.
   * @param depth Length of the one-hot dimension; becomes the output's second
   *     dimension.
   */
  constructor(numIndices: number, depth: number) {
    this.outputShape = [numIndices, depth];
    // Flat (1-D) dispatch: the shader below works on a flat output `index`.
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize);
    this.shaderKey = 'onehot';
  }

  /**
   * Returns the WGSL source of the one-hot kernel.
   *
   * The output value is chosen with `select` rather than `mix`. WGSL's `mix`
   * is a linear blend (e.g. `off * (1 - t) + on * t`), so with a 0/1 weight
   * an infinite or NaN `onValue`/`offValue` is multiplied by 0 and produces
   * NaN in positions that should hold the other value. The blend is also not
   * guaranteed to be bit-exact. `select` returns one operand unchanged.
   */
  getUserCode(): string {
    const userCode = `
      ${main('index')} {
        if (index < uniforms.size) {
          let coords = getCoordsFromIndex(index);
          // getX yields a float here (round() requires one), so round before
          // converting to an integer index for the comparison.
          let isHot = i32(round(getX(coords.x))) == coords.y;
          setOutputAtIndex(index,
                           select(uniforms.offValue, uniforms.onValue, isHot));
        }
      }
    `;

    return userCode;
  }
}
nothing calls this directly
no outgoing calls
no test coverage detected
searching dependent graphs…