hub / github.com/pytorch/pytorch / test_local_session

Method test_local_session

caffe2/python/session_test.py:18–63
(self)


class TestLocalSession(TestCase):
    def test_local_session(self):
        init_net = core.Net('init')
        src_values = Struct(
            ('uid', np.array([1, 2, 6])),
            ('value', np.array([1.4, 1.6, 1.7])))
        expected_dst = Struct(
            ('uid', np.array([2, 4, 12])),
            ('value', np.array([0.0, 0.0, 0.0])))

        with core.NameScope('init'):
            src_blobs = NewRecord(init_net, src_values)
            dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())

        def proc1(rec):
            net = core.Net('proc1')
            with core.NameScope('proc1'):
                out = NewRecord(net, rec)
            net.Add([rec.uid(), rec.uid()], [out.uid()])
            out.value.set(blob=rec.value(), unsafe=True)
            return [net], out

        def proc2(rec):
            net = core.Net('proc2')
            with core.NameScope('proc2'):
                out = NewRecord(net, rec)
            out.uid.set(blob=rec.uid(), unsafe=True)
            net.Sub([rec.value(), rec.value()], [out.value()])
            return [net], out

        src_ds = Dataset(src_blobs)
        dst_ds = Dataset(dst_blobs)

        with TaskGroup() as tg:
            out1 = pipe(src_ds.reader(), processor=proc1)
            out2 = pipe(out1, processor=proc2)
            pipe(out2, dst_ds.writer())

        ws = workspace.C.Workspace()
        FeedRecord(src_blobs, src_values, ws)
        session = LocalSession(ws)
        session.run(init_net)
        session.run(tg)
        output = FetchRecord(dst_blobs, ws=ws)

        for a, b in zip(output.field_blobs(), expected_dst.field_blobs()):
            np.testing.assert_array_equal(a, b)
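
The expected output in the test follows from the arithmetic in the two processors: proc1 adds uid to itself and passes value through, and proc2 passes uid through and subtracts value from itself. A minimal sketch of that arithmetic, using plain Python lists instead of Caffe2 blobs and nets, could look like the following. The names proc1_equivalent and proc2_equivalent and the dict-based record are hypothetical stand-ins for illustration only and are not part of Caffe2.

```python
# Plain-Python stand-ins for the two processors in the test.
# Hypothetical helpers for illustration; no Caffe2 is involved.

def proc1_equivalent(record):
    # Mirrors net.Add([uid, uid], [out.uid]): uid is doubled.
    # The value field is passed through unchanged.
    return {
        "uid": [u + u for u in record["uid"]],
        "value": list(record["value"]),
    }

def proc2_equivalent(record):
    # The uid field is passed through unchanged.
    # Mirrors net.Sub([value, value], [out.value]): value becomes zero.
    return {
        "uid": list(record["uid"]),
        "value": [v - v for v in record["value"]],
    }

src = {"uid": [1, 2, 6], "value": [1.4, 1.6, 1.7]}

# Chain the two stages in the same order as the pipe() calls in the test.
dst = proc2_equivalent(proc1_equivalent(src))
print(dst)  # {'uid': [2, 4, 12], 'value': [0.0, 0.0, 0.0]}
```

The result matches expected_dst in the test: uid becomes [2, 4, 12] and every value becomes 0.0.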

Callers

nothing calls this directly

Calls 15

reader · Method · 0.95
writer · Method · 0.95
field_blobs · Method · 0.95
Struct · Class · 0.90
NewRecord · Function · 0.90
InitEmptyRecord · Function · 0.90
Dataset · Class · 0.90
TaskGroup · Class · 0.90
pipe · Function · 0.90
FeedRecord · Function · 0.90
LocalSession · Class · 0.90
FetchRecord · Function · 0.90

Tested by

no test coverage detected