()
| 25 | |
@clear_cache_before_run()
def test_coloproxy():
    """Check the shape/metadata queries of a ``ColoProxy`` with assigned ``meta_data``.

    A ``Conv1D(3, 3)`` module is traced by ``ColoTracer`` using a meta-device
    input named ``"x"``. The traced graph is wrapped in a ``GraphModule`` and
    recompiled, and the first node of the graph is wrapped in a new
    ``ColoProxy``. A ``(4, 2)`` meta tensor is assigned to the proxy's
    ``meta_data``. The test then asserts that ``len``, ``shape``, ``dim()``,
    ``dtype`` and ``size(0)`` on the proxy report the values of that tensor.
    """
    colo_tracer = ColoTracer()
    conv = Conv1D(3, 3)
    meta_inputs = {"x": torch.rand(3, 3).to("meta")}

    traced_graph = colo_tracer.trace(root=conv, meta_args=meta_inputs)
    graph_module = GraphModule(conv, traced_graph, type(conv).__name__)
    graph_module.recompile()
    first_node = next(iter(graph_module.graph.nodes))

    proxy = ColoProxy(node=first_node, tracer=colo_tracer)
    proxy.meta_data = torch.empty(4, 2, device="meta")

    # Each query is expected to reflect the (4, 2) meta tensor assigned above.
    assert len(proxy) == 4
    assert proxy.shape[0] == 4
    assert proxy.shape[1] == 2
    assert proxy.dim() == 2
    assert proxy.dtype == torch.float32
    assert proxy.size(0) == 4
| 46 | |
| 47 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…