(self)
| 16 | |
class TestLocalSession(TestCase):
    """Tests running a reader -> processor -> writer TaskGroup on a LocalSession."""

    def test_local_session(self):
        """Pipe a small record through two processors and check the result.

        proc1 writes ``uid + uid`` into ``uid`` and forwards ``value`` as-is;
        proc2 forwards ``uid`` as-is and writes ``value - value`` into
        ``value``. The destination dataset should therefore hold doubled uids
        and all-zero values.
        """
        init_net = core.Net('init')
        src_values = Struct(
            ('uid', np.array([1, 2, 6])),
            ('value', np.array([1.4, 1.6, 1.7])))
        expected_dst = Struct(
            ('uid', np.array([2, 4, 12])),
            ('value', np.array([0.0, 0.0, 0.0])))

        with core.NameScope('init'):
            src_blobs = NewRecord(init_net, src_values)
            dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())

        def proc1(rec):
            # uid <- uid + uid; value is forwarded by pointing the output
            # field at the input blob instead of emitting a copy op.
            net = core.Net('proc1')
            with core.NameScope('proc1'):
                out = NewRecord(net, rec)
            net.Add([rec.uid(), rec.uid()], [out.uid()])
            out.value.set(blob=rec.value(), unsafe=True)
            return [net], out

        def proc2(rec):
            # uid is forwarded by pointing at the input blob;
            # value <- value - value.
            net = core.Net('proc2')
            with core.NameScope('proc2'):
                out = NewRecord(net, rec)
            out.uid.set(blob=rec.uid(), unsafe=True)
            net.Sub([rec.value(), rec.value()], [out.value()])
            return [net], out

        src_ds = Dataset(src_blobs)
        dst_ds = Dataset(dst_blobs)

        with TaskGroup() as tg:
            out1 = pipe(src_ds.reader(), processor=proc1)
            out2 = pipe(out1, processor=proc2)
            pipe(out2, dst_ds.writer())

        ws = workspace.C.Workspace()
        FeedRecord(src_blobs, src_values, ws)
        session = LocalSession(ws)
        session.run(init_net)
        session.run(tg)
        output = FetchRecord(dst_blobs, ws=ws)

        output_fields = output.field_blobs()
        expected_fields = expected_dst.field_blobs()
        # zip() stops at the shorter sequence, so without this check a
        # missing (or extra) output field would make the test pass silently.
        self.assertEqual(len(output_fields), len(expected_fields))
        for i, (actual, expected) in enumerate(
                zip(output_fields, expected_fields)):
            np.testing.assert_array_equal(
                actual, expected, err_msg='field %d' % i)
Nothing calls this directly.
No test coverage detected.