MCPcopy
hub / github.com/tensorflow/tfjs / OneHotProgram

Class OneHotProgram

tfjs-backend-webgpu/src/onehot_webgpu.ts:21–52  ·  view source on GitHub ↗

Source retrieved from the content-addressed store (hash-verified)

19import {computeDispatch, flatDispatchLayout} from './webgpu_util';
20
/**
 * WebGPU program that expands a vector of indices `x` into a
 * `[numIndices, depth]` one-hot matrix.
 *
 * Output element `[i, j]` is `onValue` when `round(x[i]) == j` and
 * `offValue` otherwise. The on/off values are read from the `onValue` /
 * `offValue` uniforms rather than being baked into the shader text, so the
 * shader source does not depend on them.
 */
export class OneHotProgram implements WebGPUProgram {
  outputShape: number[];
  shaderKey: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  variableNames = ['x'];
  // Extra uniform fields read by the shader as `uniforms.onValue` and
  // `uniforms.offValue`; presumably populated by the kernel that runs this
  // program — confirm against the caller.
  uniforms = 'onValue : f32, offValue : f32,';
  workgroupSize: [number, number, number] = [64, 1, 1];
  // Makes `uniforms.size` available to the shader, which uses it to skip
  // invocations past the end of the flattened output (presumably the
  // dispatch is rounded up to whole workgroups of 64).
  size = true;

  /**
   * @param numIndices Number of indices in `x`; becomes the output row count.
   * @param depth Number of one-hot classes; becomes the output column count.
   */
  constructor(numIndices: number, depth: number) {
    this.outputShape = [numIndices, depth];
    // The shader handles one flat output index per invocation.
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize);
    this.shaderKey = 'onehot';
  }

  /**
   * Returns the WGSL body. Each in-range invocation recovers the
   * (row, column) coordinates of its flat output index and writes `onValue`
   * when the rounded index stored at `x[row]` equals the column, `offValue`
   * otherwise.
   */
  getUserCode(): string {
    // `select(f, t, cond)` returns exactly `t` or `f`. The previous
    // `mix(off, on, f32(cond))` form computes `off * (1 - t) + on * t`, which
    // yields NaN when the *unselected* value is non-finite (e.g. Infinity * 0).
    const userCode = `
      ${main('index')} {
        if(index < uniforms.size) {
          let coords = getCoordsFromIndex(index);
          setOutputAtIndex(index, select(uniforms.offValue, uniforms.onValue,
                                         i32(round(getX(coords.x))) == coords.y));
        }
      }
    `;

    return userCode;
  }
}

Callers

nothing calls this directly

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…