* Fetch a list of Tensors into the TensorCache. * * @param tensorCacheUrl The cache URL. * @param list The list of array data. * @param device The device to store the data to. * @param artifactCache The artifact cache. * @param signal An optional AbortSignal to abort the fetch.
(
tensorCacheUrl: string,
list: Array<TensorShardEntry>,
device: DLDevice,
artifactCache: ArtifactCacheTemplate,
signal?: AbortSignal,
)
| 1317 | * @param signal An optional AbortSignal to abort the fetch |
| 1318 | */ |
| 1319 | private async fetchTensorCacheInternal( |
| 1320 | tensorCacheUrl: string, |
| 1321 | list: Array<TensorShardEntry>, |
| 1322 | device: DLDevice, |
| 1323 | artifactCache: ArtifactCacheTemplate, |
| 1324 | signal?: AbortSignal, |
| 1325 | ) { |
| 1326 | const perf = compact.getPerformance(); |
| 1327 | const tstart = perf.now(); |
| 1328 | let totalBytes = 0; |
| 1329 | for (let i = 0; i < list.length; ++i) { |
| 1330 | totalBytes += list[i].nbytes; |
| 1331 | } |
| 1332 | let fetchedBytes = 0; |
| 1333 | let fetchedShards = 0; |
| 1334 | let timeElapsed = 0; |
| 1335 | |
| 1336 | const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, tensorCacheUrl).href)); |
| 1337 | |
| 1338 | // `loading`: we have finished downloading (or already cacheOnly) and are loading onto WebGPU |
| 1339 | const reportCallback = (iter: number, loading = false) => { |
| 1340 | // report |
| 1341 | for (let j = 0; j < this.initProgressCallback.length; ++j) { |
| 1342 | let text: string; |
| 1343 | if (loading) { |
| 1344 | text = "Loading model from cache[" + iter + "/" + list.length + "]: "; |
| 1345 | text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB loaded. " |
| 1346 | text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " |
| 1347 | text += timeElapsed + " secs elapsed."; |
| 1348 | } else { |
| 1349 | text = "Fetching param cache[" + iter + "/" + list.length + "]: "; |
| 1350 | text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. " |
| 1351 | text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " |
| 1352 | text += timeElapsed + " secs elapsed."; |
| 1353 | text += " It can take a while when we first visit this page to populate the cache." |
| 1354 | text += " Later refreshes will become faster."; |
| 1355 | } |
| 1356 | this.initProgressCallback[j]({ |
| 1357 | progress: fetchedBytes / totalBytes, |
| 1358 | timeElapsed: timeElapsed, |
| 1359 | text: text |
| 1360 | }); |
| 1361 | } |
| 1362 | }; |
| 1363 | |
| 1364 | for (let j = 0; j < this.initProgressCallback.length; ++j) { |
| 1365 | this.initProgressCallback[j]({ |
| 1366 | progress: fetchedBytes / totalBytes, |
| 1367 | timeElapsed: 0, |
| 1368 | text: "Start to fetch params", |
| 1369 | }); |
| 1370 | } |
| 1371 | |
| 1372 | // First download all shards to the cache in parallel if they are not yet cached |
| 1373 | const downloadCache = async (start: number, end: number) => { |
| 1374 | // Download params [start, end) from `list` |
| 1375 | for (let i = start; i < end; i++) { |
| 1376 | const shard = list[i]; |
no test coverage detected