MCPcopy Index your code
hub / github.com/pytorch/pytorch / test_runtime_threads

Method test_runtime_threads

caffe2/python/dataio_test.py:152–200  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

150
151class TestReaderWithLimit(TestCase):
def test_runtime_threads(self):
    """Check counter totals gathered by a multi-threaded pipe processor.

    The processor registers work in the task-level and per-instance
    init/exit sections as well as in the per-record body, and exposes
    three counter values as final outputs. The test runs it over the
    full source dataset and then over a count-limited reader, asserting
    the three totals after each run.
    """
    net_ws = workspace.C.Workspace()
    runner = LocalSession(net_ws)
    dataset = make_source_dataset(net_ws)
    # Filled in by the processor with the three final-output handles.
    totals = [None] * 3

    def count_records(rec):
        # executed once: the three counters shared by every thread
        with ops.task_init():
            shared = [
                ops.CreateCounter([], [blob_name])
                for blob_name in (
                    'global_counter',
                    'global_counter2',
                    'global_counter3',
                )
            ]
        per_record, replayed, per_thread = shared

        # executed once per thread: a counter private to this instance
        with ops.task_instance_init():
            local_counter = ops.CreateCounter([], ['task_counter'])

        # executed on each iteration
        ops.CountUp(per_record)
        ops.CountUp(local_counter)

        # executed once per thread: replay the private count into the
        # second shared counter, then bump the third one a single time
        with ops.task_instance_exit():
            local_count = ops.RetrieveCount(local_counter)
            with ops.loop(local_count):
                ops.CountUp(replayed)
            ops.CountUp(per_thread)

        # executed once: publish the shared counters as final outputs
        with ops.task_exit():
            for slot, counter in enumerate(shared):
                totals[slot] = final_output(ops.RetrieveCount(counter))
        return rec

    def expect_totals(*expected):
        # Fetch and compare the outputs one by one, in slot order.
        for slot, value in enumerate(expected):
            self.assertEqual(totals[slot].fetch(), value)

    # Read full data set from original reader
    with TaskGroup() as full_group:
        pipe(
            dataset.reader(),
            processor=count_records,
            num_runtime_threads=8,
        )
    runner.run(full_group)
    expect_totals(100, 100, 8)

    # Read with a count-limited reader
    with TaskGroup() as limited_group:
        first_stage = pipe(dataset.reader(), num_runtime_threads=2)
        limited_reader = ReaderWithLimit(first_stage.reader(), num_iter=25)
        second_stage = pipe(limited_reader, num_runtime_threads=3)
        pipe(second_stage, num_runtime_threads=6, processor=count_records)
    runner.run(limited_group)
    expect_totals(25, 25, 6)
201
202 def _test_limit_reader_init_shared(self, size):
203 ws = workspace.C.Workspace()

Callers

nothing calls this directly

Calls 9

LocalSessionClass · 0.90
TaskGroupClass · 0.90
pipeFunction · 0.90
ReaderWithLimitClass · 0.90
make_source_datasetFunction · 0.85
readerMethod · 0.45
runMethod · 0.45
assertEqualMethod · 0.45
fetchMethod · 0.45

Tested by

no test coverage detected