MCPcopy Index your code
hub / github.com/pytorch/pytorch / test_dequeue_many

Method test_dequeue_many

caffe2/python/pipeline_test.py:21–77  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

19
class TestPipeline(TestCase):
    def test_dequeue_many(self):
        """Run a two-stage pipeline through a batching queue and check output.

        ``N`` source records are read from a dataset and passed through
        ``proc1``. The result goes into a ``Queue`` configured with
        ``num_dequeue_records=NUM_DEQUEUE_RECORDS``, then through ``proc2``,
        and is finally written to a destination dataset.

        - ``proc1`` replaces ``uid`` with ``uid + uid`` and forwards ``value``.
        - ``proc2`` forwards ``uid``, replaces ``value`` with ``value - value``,
          and increments ``counter`` by one.

        The test expects ``proc2``'s ops to run once per dequeued batch, i.e.
        ``ceil(N / NUM_DEQUEUE_RECORDS)`` times. It also expects every
        destination field to equal ``expected_dst``.
        """
        init_net = core.Net('init')
        N = 17
        NUM_DEQUEUE_RECORDS = 3
        src_values = Struct(
            ('uid', np.arange(N)),
            ('value', 0.1 * np.arange(N)))
        expected_dst = Struct(
            ('uid', 2 * np.arange(N)),
            ('value', np.zeros(N)))

        with core.NameScope('init'):
            src_blobs = NewRecord(init_net, src_values)
            dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())
            # Incremented by proc2's ops; used to count how often they run.
            counter = init_net.Const(0)
            ONE = init_net.Const(1)

        def proc1(rec):
            # uid -> uid + uid; value is forwarded as-is (aliased blob).
            with core.NameScope('proc1'):
                out = NewRecord(ops, rec)
                ops.Add([rec.uid(), rec.uid()], [out.uid()])
                out.value.set(blob=rec.value(), unsafe=True)
                return out

        def proc2(rec):
            # uid is forwarded as-is; value -> value - value (all zeros).
            with core.NameScope('proc2'):
                out = NewRecord(ops, rec)
                out.uid.set(blob=rec.uid(), unsafe=True)
                ops.Sub([rec.value(), rec.value()], [out.value()])
                ops.Add([counter, ONE], [counter])
                return out

        src_ds = Dataset(src_blobs)
        dst_ds = Dataset(dst_blobs)

        with TaskGroup() as tg:
            out1 = pipe(
                src_ds.reader(),
                output=Queue(
                    capacity=11, num_dequeue_records=NUM_DEQUEUE_RECORDS),
                processor=proc1)
            out2 = pipe(out1, processor=proc2)
            pipe(out2, dst_ds.writer())

        ws = workspace.C.Workspace()
        FeedRecord(src_blobs, src_values, ws)
        session = LocalSession(ws)
        session.run(init_net)
        session.run(tg)
        output = FetchRecord(dst_blobs, ws=ws)
        num_dequeues = ws.blobs[str(counter)].fetch()

        # Exact integer ceiling division: one batch per started group of
        # NUM_DEQUEUE_RECORDS records.
        expected_dequeues = (
            (N + NUM_DEQUEUE_RECORDS - 1) // NUM_DEQUEUE_RECORDS)
        self.assertEqual(num_dequeues, expected_dequeues)

        output_fields = list(output.field_blobs())
        expected_fields = list(expected_dst.field_blobs())
        # zip() silently stops at the shorter sequence; check the field
        # counts first so a schema mismatch cannot pass vacuously.
        self.assertEqual(len(output_fields), len(expected_fields))
        for actual, expected in zip(output_fields, expected_fields):
            np.testing.assert_array_equal(actual, expected)

Callers

nothing calls this directly

Calls 15

Const · Method · 0.95
reader · Method · 0.95
writer · Method · 0.95
field_blobs · Method · 0.95
Struct · Class · 0.90
NewRecord · Function · 0.90
InitEmptyRecord · Function · 0.90
Dataset · Class · 0.90
TaskGroup · Class · 0.90
pipe · Function · 0.90
Queue · Class · 0.90
FeedRecord · Function · 0.90

Tested by

no test coverage detected