Sample multiple entities. Args: sample_args: Instance of `SampleArgs`. The min and max entropy and module count will be split up between the various entities sampled. values: List of values to sample. Returns: List of `Entity` of the same length as `values`.
(self, sample_args, values)
| 338 | .format(value, type(value))) |
| 339 | |
| 340 | def sample(self, sample_args, values): |
| 341 | """Sample multiple entities. |
| 342 | |
| 343 | Args: |
| 344 | sample_args: Instance of `SampleArgs`. The min and max entropy |
| 345 | and module count will be split up betwene the various entities |
| 346 | sampled. |
| 347 | values: List of values to sample. |
| 348 | |
| 349 | Returns: |
| 350 | List of `Entity` of the same length as `types`. |
| 351 | |
| 352 | Raises: |
| 353 | RuntimeError: If one of the modules generates a non-`Entity`. |
| 354 | """ |
| 355 | # Can only sample children once. |
| 356 | assert self._module_count == 1 |
| 357 | assert not self._child_symbols |
| 358 | assert not self._child_entities |
| 359 | |
| 360 | if isinstance(sample_args, PreSampleArgs): |
| 361 | sample_args = sample_args() |
| 362 | sample_args_split = sample_args.split(len(values)) |
| 363 | |
| 364 | def all_symbols(): |
| 365 | return (self._relation_symbols |
| 366 | .union(self._self_symbols) |
| 367 | .union(self._child_symbols)) |
| 368 | |
| 369 | for value, child_sample_args in zip(values, sample_args_split): |
| 370 | if number.is_integer(value): |
| 371 | value = sympy.Integer(value) |
| 372 | |
| 373 | all_symbols_ = all_symbols() |
| 374 | context = Context(all_symbols_) |
| 375 | |
| 376 | if child_sample_args.num_modules == 0: |
| 377 | entity = self._value_entity(value, context) |
| 378 | else: |
| 379 | sampler = self._sampler(value, child_sample_args) |
| 380 | entity = sampler(value, child_sample_args, context) |
| 381 | if not isinstance(entity, Entity): |
| 382 | raise RuntimeError( |
| 383 | 'Expected entity, but got {} instead'.format(entity)) |
| 384 | if (not number.is_integer_or_rational_or_decimal(entity.value) |
| 385 | and not isinstance(entity.value, Polynomial)): |
| 386 | raise RuntimeError('sampler {} returned invalid value of type {}' |
| 387 | .format(sampler, type(entity.value))) |
| 388 | if ((number.is_integer_or_rational_or_decimal(value) |
| 389 | and entity.value != value) |
| 390 | or (isinstance(value, Polynomial) and not np.array_equal( |
| 391 | entity.value.coefficients, value.coefficients))): |
| 392 | raise RuntimeError( |
| 393 | 'entity values differ, sampler={} wanted={} got={}' |
| 394 | .format(sampler, value, entity.value)) |
| 395 | if child_sample_args.num_modules != context.module_count: |
| 396 | raise RuntimeError( |
| 397 | 'unused modules, value={} sample_args={} context.module_count={},' |