| 34 | }; |
| 35 | |
| 36 | export class NodeJSKernelBackend extends KernelBackend { |
| 37 | binding: TFJSBinding; |
| 38 | isGPUPackage: boolean; |
| 39 | isUsingGpuDevice: boolean; |
| 40 | private tensorMap: tf.DataStorage<TensorData>; |
| 41 | |
/**
 * Builds a backend on top of the given native binding.
 *
 * @param binding Native binding object; stored on the instance and queried
 *     once for whether a GPU device is in use.
 * @param packageName Name of the package that created this backend; only
 *     compared against the GPU package name.
 */
constructor(binding: TFJSBinding, packageName: string) {
  super();
  const gpuPackageName = '@tensorflow/tfjs-node-gpu';
  this.binding = binding;
  this.isGPUPackage = gpuPackageName === packageName;
  // Same object as `this.binding`, which was assigned on the line above.
  this.isUsingGpuDevice = binding.isUsingGpuDevice();
  this.tensorMap = new tf.DataStorage<TensorData>(this, tf.engine());
}
| 49 | |
/**
 * Translates a dtype name into the matching `TF_*` constant exposed by the
 * binding.
 *
 * @param dtype The dtype name to translate.
 * @returns The binding constant for 'float32', 'int32', 'bool', 'complex64'
 *     or 'string'.
 * @throws Error when `dtype` is none of the names listed above.
 */
getDTypeInteger(dtype: DataType): number {
  if (dtype === 'float32') {
    return this.binding.TF_FLOAT;
  }
  if (dtype === 'int32') {
    return this.binding.TF_INT32;
  }
  if (dtype === 'bool') {
    return this.binding.TF_BOOL;
  }
  if (dtype === 'complex64') {
    return this.binding.TF_COMPLEX64;
  }
  if (dtype === 'string') {
    return this.binding.TF_STRING;
  }
  throw new Error(`Unsupported DType: ${dtype}`);
}
| 66 | |
/**
 * Looks up the binding's dtype integer for the dtype of the given tensor.
 *
 * @param value Tensor whose `dtype` is translated.
 * @returns The result of `getDTypeInteger` for that dtype.
 */
private typeAttributeFromTensor(value: Tensor): number {
  const {dtype} = value;
  return this.getDTypeInteger(dtype);
}
| 70 | |
| 71 | // Creates a new Tensor and maps the dataId to the passed in ID. |
| 72 | private createOutputTensor(metadata: TensorMetadata): Tensor { |
| 73 | const newId = {}; |
| 74 | |
| 75 | this.tensorMap.set(newId, { |
| 76 | shape: metadata.shape, |
| 77 | dtype: metadata.dtype, |
| 78 | id: metadata.id, |
| 79 | values: null, |
| 80 | refCount: 1 |
| 81 | }); |
| 82 | |
| 83 | let dtype: DataType; |
| 84 | switch (metadata.dtype) { |
| 85 | case this.binding.TF_FLOAT: |
| 86 | dtype = 'float32'; |
| 87 | break; |
| 88 | case this.binding.TF_INT32: |
| 89 | dtype = 'int32'; |
| 90 | break; |
| 91 | case this.binding.TF_INT64: |
| 92 | console.warn('INT64 output tensor will be stored as BigInt64Array.'); |
| 93 | // INT64 is not supported in TFJS yet, cast it to int32. |
nothing calls this directly
no outgoing calls
no test coverage detected
searching dependent graphs…