Map shared-memory tensors from the server to the client. Parameters: partition_book (GraphPartitionBook) — stores the partition information.
(self, partition_book)
| 1281 | self.barrier() |
| 1282 | |
| 1283 | def map_shared_data(self, partition_book): |
| 1284 | """Mapping shared-memory tensor from server to client. |
| 1285 | |
| 1286 | Parameters |
| 1287 | ---------- |
| 1288 | partition_book : GraphPartitionBook |
| 1289 | Store the partition information |
| 1290 | """ |
| 1291 | # Get all partition policies |
| 1292 | for ntype in partition_book.ntypes: |
| 1293 | policy = NodePartitionPolicy(partition_book, ntype) |
| 1294 | self._all_possible_part_policy[policy.policy_str] = policy |
| 1295 | for etype in partition_book.canonical_etypes: |
| 1296 | policy = EdgePartitionPolicy(partition_book, etype) |
| 1297 | self._all_possible_part_policy[policy.policy_str] = policy |
| 1298 | |
| 1299 | # Get shared data from server side |
| 1300 | self.barrier() |
| 1301 | request = GetSharedDataRequest(GET_SHARED_MSG) |
| 1302 | rpc.send_request(self._main_server_id, request) |
| 1303 | response = rpc.recv_response() |
| 1304 | for name, meta in response.meta.items(): |
| 1305 | if name not in self._data_name_list: |
| 1306 | shape, dtype, policy_str = meta |
| 1307 | assert policy_str in self._all_possible_part_policy |
| 1308 | shared_data = empty_shared_mem( |
| 1309 | name + "-kvdata-", False, shape, dtype |
| 1310 | ) |
| 1311 | dlpack = shared_data.to_dlpack() |
| 1312 | self._data_store[name] = F.zerocopy_from_dlpack(dlpack) |
| 1313 | self._part_policy[name] = self._all_possible_part_policy[ |
| 1314 | policy_str |
| 1315 | ] |
| 1316 | self._pull_handlers[name] = default_pull_handler |
| 1317 | self._push_handlers[name] = default_push_handler |
| 1318 | # Get full data shape across servers |
| 1319 | for name, meta in response.meta.items(): |
| 1320 | if name not in self._data_name_list: |
| 1321 | shape, _, _ = meta |
| 1322 | data_shape = list(shape) |
| 1323 | data_shape[0] = 0 |
| 1324 | request = GetPartShapeRequest(name) |
| 1325 | # send request to all main server nodes |
| 1326 | for machine_id in range(self._machine_count): |
| 1327 | server_id = machine_id * self._group_count |
| 1328 | rpc.send_request(server_id, request) |
| 1329 | # recv response from all the main server nodes |
| 1330 | for _ in range(self._machine_count): |
| 1331 | res = rpc.recv_response() |
| 1332 | data_shape[0] += res.shape[0] |
| 1333 | self._full_data_shape[name] = tuple(data_shape) |
| 1334 | # Send meta data to backup servers |
| 1335 | for name, meta in response.meta.items(): |
| 1336 | shape, dtype, policy_str = meta |
| 1337 | request = SendMetaToBackupRequest( |
| 1338 | name, |
| 1339 | dtype, |
| 1340 | shape, |