(webGPUBackend: WebGPUBackend)
| 351 | }); |
| 352 | |
/**
 * Shared body for the parallel-compilation tests.
 *
 * Runs `tf.add` once with `WEBGPU_ENGINE_COMPILE_ONLY` enabled, waits for
 * `checkCompileCompletionAsync()`, then runs the same op again with the flag
 * disabled and checks the numeric result. Finally it verifies that the GPU
 * byte count, the tensor count and the backend's data-id count are back at
 * the values observed on entry.
 *
 * @param webGPUBackend Backend instance whose compile completion is awaited
 *     and whose data ids are counted.
 */
async function parallelCompilationCommon(webGPUBackend: WebGPUBackend) {
  // Snapshot the bookkeeping counters so they can be compared at the end.
  const startNumBytes = (tf.memory() as WebGPUMemoryInfo).numBytesInGPU;
  const startTensor = tf.memory().numTensors;
  const startDataBuckets = webGPUBackend.numDataIds();

  const a1 = tf.tensor1d([1, 1, 1]);
  const b1 = tf.tensor1d([1, 1, 1]);

  // Parallel compile.
  tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true);
  try {
    const c1 = tf.add(a1, b1);
    await webGPUBackend.checkCompileCompletionAsync();

    // Actual inference.
    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);
    const c2 = tf.add(a1, b1);
    expectArraysEqual(await c2.data(), [2, 2, 2]);

    tf.dispose([a1, b1, c1, c2]);
  } finally {
    // The flag is global. If anything above throws while it is still `true`,
    // make sure it does not stay enabled for the tests that run afterwards.
    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);
  }

  const endNumBytes = (tf.memory() as WebGPUMemoryInfo).numBytesInGPU;
  const endTensor = tf.memory().numTensors;
  const endDataBuckets = webGPUBackend.numDataIds();

  // We only check numBytesInGPU. For parallel compilation,
  // numBytesInGPUAllocated will be larger because of the two-pass
  // uploadToGPU, but everything will be freed, resulting in endNumBytes
  // being equal to startNumBytes.
  expect(startNumBytes).toEqual(endNumBytes);
  expect(startTensor).toEqual(endTensor);
  expect(endDataBuckets).toEqual(startDataBuckets);
}
| 384 | |
| 385 | describeWebGPU('parallel compilation', () => { |
| 386 | let prevBackend: string; |
no test coverage detected
searching dependent graphs…