(self)
| 150 | |
| 151 | class TestReaderWithLimit(TestCase): |
def test_runtime_threads(self):
    # Verifies how many times each task hook of a pipe processor fires
    # when the pipe runs on several runtime threads: first over the full
    # source reader, then over a reader capped with ReaderWithLimit.
    wspace = workspace.C.Workspace()
    local_session = LocalSession(wspace)
    source = make_source_dataset(wspace)
    # One slot per global counter; filled in by `proc` with the counters'
    # final outputs.
    results = [None, None, None]

    def proc(rec):
        # Task-level setup (runs once): the three global counters.
        with ops.task_init():
            total_iters, summed_iters, instances = [
                ops.CreateCounter([], [blob_name])
                for blob_name in (
                    'global_counter', 'global_counter2', 'global_counter3')
            ]
        # Per-thread setup (runs once per thread): a thread-local counter.
        with ops.task_instance_init():
            per_thread = ops.CreateCounter([], ['task_counter'])
        # Per-iteration work: count the record globally and per thread.
        ops.CountUp(total_iters)
        ops.CountUp(per_thread)
        # Per-thread teardown (runs once per thread): replay this thread's
        # count into the second global counter, then record one finished
        # thread in the third.
        with ops.task_instance_exit():
            with ops.loop(ops.RetrieveCount(per_thread)):
                ops.CountUp(summed_iters)
            ops.CountUp(instances)
        # Task-level teardown (runs once): publish the global counts.
        with ops.task_exit():
            for slot, counter in enumerate(
                    (total_iters, summed_iters, instances)):
                results[slot] = final_output(ops.RetrieveCount(counter))
        return rec

    def check_totals(*expected):
        # Compare the fetched counter outputs with the expected values,
        # in slot order.
        for total, want in zip(results, expected):
            self.assertEqual(total.fetch(), want)

    # Full pass over the original reader on 8 runtime threads.
    with TaskGroup() as full_group:
        pipe(source.reader(), processor=proc, num_runtime_threads=8)
    local_session.run(full_group)
    check_totals(100, 100, 8)

    # Same source, but capped at 25 iterations by ReaderWithLimit, with
    # the processor attached to the last stage (6 runtime threads).
    with TaskGroup() as limited_group:
        source_queue = pipe(source.reader(), num_runtime_threads=2)
        capped_reader = ReaderWithLimit(source_queue.reader(), num_iter=25)
        capped_queue = pipe(capped_reader, num_runtime_threads=3)
        pipe(capped_queue, num_runtime_threads=6, processor=proc)
    local_session.run(limited_group)
    check_totals(25, 25, 6)
| 201 | |
| 202 | def _test_limit_reader_init_shared(self, size): |
| 203 | ws = workspace.C.Workspace() |
nothing calls this directly
no test coverage detected