(self)
| 19 | |
class TestPipeline(TestCase):
    def test_dequeue_many(self):
        """Run reader -> queue -> proc1 -> proc2 -> writer and check results.

        The source dataset holds N records with ``uid = i`` and
        ``value = 0.1 * i``. ``proc1`` writes ``uid + uid`` and forwards
        ``value`` unchanged; ``proc2`` forwards ``uid``, writes
        ``value - value`` (i.e. zero) and adds an op that increments a
        counter blob. The first pipe's output queue is built with
        ``num_dequeue_records=NUM_DEQUEUE_RECORDS``, and the test expects
        the counter to end at ``ceil(N / NUM_DEQUEUE_RECORDS)`` and the
        written dataset to equal ``expected_dst``.
        """
        init_net = core.Net('init')
        N = 17
        NUM_DEQUEUE_RECORDS = 3
        src_values = Struct(
            ('uid', np.arange(N)),
            ('value', 0.1 * np.arange(N)))
        expected_dst = Struct(
            ('uid', 2 * np.arange(N)),
            ('value', np.zeros(N)))

        with core.NameScope('init'):
            src_blobs = NewRecord(init_net, src_values)
            dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())
            counter = init_net.Const(0)
            ONE = init_net.Const(1)

        def proc1(rec):
            # Doubles uid (uid + uid); value is aliased straight through to
            # the output record without adding a copy op.
            with core.NameScope('proc1'):
                out = NewRecord(ops, rec)
            ops.Add([rec.uid(), rec.uid()], [out.uid()])
            out.value.set(blob=rec.value(), unsafe=True)
            return out

        def proc2(rec):
            # Passes uid through, zeroes value (value - value) and increments
            # the shared counter so the test can see how often this ran.
            with core.NameScope('proc2'):
                out = NewRecord(ops, rec)
            out.uid.set(blob=rec.uid(), unsafe=True)
            ops.Sub([rec.value(), rec.value()], [out.value()])
            ops.Add([counter, ONE], [counter])
            return out

        src_ds = Dataset(src_blobs)
        dst_ds = Dataset(dst_blobs)

        with TaskGroup() as tg:
            out1 = pipe(
                src_ds.reader(),
                output=Queue(
                    capacity=11, num_dequeue_records=NUM_DEQUEUE_RECORDS),
                processor=proc1)
            out2 = pipe(out1, processor=proc2)
            pipe(out2, dst_ds.writer())

        ws = workspace.C.Workspace()
        FeedRecord(src_blobs, src_values, ws)
        session = LocalSession(ws)
        session.run(init_net)
        session.run(tg)
        output = FetchRecord(dst_blobs, ws=ws)
        num_dequeues = ws.blobs[str(counter)].fetch()

        expected_dequeues = int(math.ceil(float(N) / NUM_DEQUEUE_RECORDS))
        self.assertEqual(num_dequeues, expected_dequeues)

        actual_blobs = output.field_blobs()
        expected_blobs = expected_dst.field_blobs()
        # zip() stops at the shorter sequence, so compare the field counts
        # first; otherwise a missing or extra output field would go unnoticed.
        self.assertEqual(len(actual_blobs), len(expected_blobs))
        for actual, expected in zip(actual_blobs, expected_blobs):
            np.testing.assert_array_equal(actual, expected)
nothing calls this directly
no test coverage detected