Internal API to set edge(s) features. `data` is a dictionary from the feature name to feature tensor. Each tensor is of shape (B, D1, D2, ...), where B is the number of edges to be updated, and (D1, D2, ...) is the shape of the edge representation tensor. All updates are done out of place to work with autograd.
(self, etid, edges, data)
| 4413 | return self._node_frames[ntid].pop(key) |
| 4414 | |
| 4415 | def _set_e_repr(self, etid, edges, data): |
| 4416 | """Internal API to set edge(s) features. |
| 4417 | |
| 4418 | `data` is a dictionary from the feature name to feature tensor. Each tensor |
| 4419 | is of shape (B, D1, D2, ...), where B is the number of edges to be updated, |
| 4420 | and (D1, D2, ...) be the shape of the edge representation tensor. |
| 4421 | |
| 4422 | All update will be done out of place to work with autograd. |
| 4423 | |
| 4424 | Parameters |
| 4425 | ---------- |
| 4426 | etid : int |
| 4427 | Edge type id. |
| 4428 | edges : edges |
| 4429 | Edges can be either |
| 4430 | |
| 4431 | * A pair of endpoint nodes (u, v), where u is the node ID of source |
| 4432 | node type and v is that of destination node type. |
| 4433 | * A tensor of edge ids of the given type. |
| 4434 | |
| 4435 | The default value is all the edges. |
| 4436 | data : tensor or dict of tensor |
| 4437 | Edge representation. |
| 4438 | """ |
| 4439 | # parse argument |
| 4440 | if not is_all(edges): |
| 4441 | eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges") |
| 4442 | |
| 4443 | # sanity check |
| 4444 | if not utils.is_dict_like(data): |
| 4445 | raise DGLError( |
| 4446 | "Expect dictionary type for feature data." |
| 4447 | ' Got "%s" instead.' % type(data) |
| 4448 | ) |
| 4449 | |
| 4450 | if is_all(edges): |
| 4451 | num_edges = self._graph.num_edges(etid) |
| 4452 | else: |
| 4453 | num_edges = len(eid) |
| 4454 | for key, val in data.items(): |
| 4455 | nfeats = F.shape(val)[0] |
| 4456 | if nfeats != num_edges: |
| 4457 | raise DGLError( |
| 4458 | "Expect number of features to match number of edges." |
| 4459 | " Got %d and %d instead." % (nfeats, num_edges) |
| 4460 | ) |
| 4461 | if F.context(val) != self.device: |
| 4462 | raise DGLError( |
| 4463 | 'Cannot assign edge feature "{}" on device {} to a graph on' |
| 4464 | " device {}. Call DGLGraph.to() to copy the graph to the" |
| 4465 | " same device.".format(key, F.context(val), self.device) |
| 4466 | ) |
| 4467 | # To prevent users from doing things like: |
| 4468 | # |
| 4469 | # g.pin_memory_() |
| 4470 | # g.edata['x'] = torch.randn(...) |
| 4471 | # sg = g.sample_neighbors(torch.LongTensor([...]).cuda()) |
| 4472 | # sg.edata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing |