Enter a local scope context for the graph. Inside a local scope, any out-of-place mutation to the feature data is not reflected in the original graph, which makes it easier to use in a function scope (e.g. the forward computation of a model). If custom initializers have been set on the graph, the local scope uses the same initializers for node features and edge features.
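The behavior in the summary can be modeled in a few lines of plain Python. The sketch below uses a hypothetical `ToyGraph`, not DGL's implementation. It assumes that a local scope swaps the feature dictionary for a shallow copy on entry and restores the saved dictionary on exit. Under that model, features added or reassigned inside the scope are discarded, while the feature objects themselves stay shared with the original.

```python
from contextlib import contextmanager


class ToyGraph:
    """Hypothetical stand-in for a graph: it only holds an edge-feature dict."""

    def __init__(self):
        self.edata = {}

    @contextmanager
    def local_scope(self):
        saved = self.edata
        # Shallow copy: a new dict, but the feature objects inside are shared.
        self.edata = dict(saved)
        try:
            yield
        finally:
            # Restore the saved dict, dropping keys added or rebound in the scope.
            self.edata = saved


g = ToyGraph()
g.edata["h"] = [0, 0, 0]

with g.local_scope():
    g.edata["h"] = [1, 1, 1]   # out-of-place: rebinds the key in the copy only
    g.edata["h2"] = [2, 2, 2]  # new feature: exists only in the copy

print(g.edata)  # {'h': [0, 0, 0]}
```

The `try`/`finally` is a design choice of this sketch. It restores the saved dictionary even if the body of the `with` block raises.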
```python
@contextmanager
def local_scope(self):
    """Enter a local scope context for the graph.

    Inside a local scope, any out-of-place mutation to the feature data is
    not reflected in the original graph, which makes it easier to use in a
    function scope (e.g. the forward computation of a model).

    If custom initializers have been set on the graph, the local scope uses
    the same initializers for node features and edge features.

    Notes
    -----
    In-place operations are still reflected in the original graph. This
    function has little overhead when the number of feature tensors in this
    graph is small.

    Examples
    --------

    The following example uses the PyTorch backend.

    >>> import dgl
    >>> import torch

    Create a function for computation on graphs.

    >>> def foo(g):
    ...     with g.local_scope():
    ...         g.edata['h'] = torch.ones((g.num_edges(), 3))
    ...         g.edata['h2'] = torch.ones((g.num_edges(), 3))
    ...         return g.edata['h']

    ``local_scope`` prevents the function from changing the graph features
    once it exits.

    >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))
    >>> g.edata['h'] = torch.zeros((g.num_edges(), 3))
    >>> newh = foo(g)
    >>> print(g.edata['h'])  # still a tensor of all zeros
    tensor([[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]])
    >>> 'h2' in g.edata  # the feature set inside the function scope is gone
    False

    In-place operations are still reflected in the original graph.

    >>> def foo(g):
    ...     with g.local_scope():
    ...         # in-place operation
    ...         g.edata['h'] += 1
    ...         return g.edata['h']

    >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))
    >>> g.edata['h'] = torch.zeros((g.num_edges(), 1))
    >>> newh = foo(g)
    >>> print(g.edata['h'])  # the original feature has changed
    tensor([[1.],
            [1.],
            [1.]])
```
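The in-place pitfall comes from how Python evaluates augmented assignment on a mutable object. `g.edata['h'] += 1` mutates the tensor that the scope still shares with the original graph. `g.edata['h'] = g.edata['h'] + 1` creates a new tensor that lives only inside the scope. The sketch below reproduces that distinction without DGL or PyTorch. Both stand-ins are assumptions for illustration: `Vec` is a hypothetical class that mimics a tensor's `+` (new object) and `+=` (in place), and a shallow-copied dict takes the place of the scoped feature store.

```python
class Vec:
    """Hypothetical mutable vector mimicking tensor `+` vs `+=` semantics."""

    def __init__(self, values):
        self.values = list(values)

    def __add__(self, k):
        # Out-of-place: returns a new Vec and leaves self untouched.
        return Vec(v + k for v in self.values)

    def __iadd__(self, k):
        # In-place: mutates self and returns the same object.
        for i in range(len(self.values)):
            self.values[i] += k
        return self


# The "graph" features, plus a shallow copy standing in for a local scope.
original = {"h": Vec([0, 0, 0])}

scoped = dict(original)
scoped["h"] = scoped["h"] + 1  # out-of-place: only the copy sees the new Vec
print(original["h"].values)  # [0, 0, 0]

scoped = dict(original)
scoped["h"] += 1  # in-place: mutates the Vec shared with the original dict
print(original["h"].values)  # [1, 1, 1]
```

Inside a local scope, writing updates in the out-of-place form keeps them from leaking back into the original graph.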