Serialize local metadata for scalar all-gather.
(self, *, is_spec_enabled: bool, is_microbatch_enabled: bool)
| 63 | return field_names |
| 64 | |
| 65 | def values(self, *, is_spec_enabled: bool, is_microbatch_enabled: bool) -> list[int]: |
| 66 | """Serialize local metadata for scalar all-gather.""" |
| 67 | raw_values = { |
| 68 | 'is_decoding': int(self.is_decoding), |
| 69 | 'is_dummy': int(self.is_dummy), |
| 70 | 'num_tokens': self.num_tokens, |
| 71 | 'is_sleeping': int(self.is_sleeping), |
| 72 | 'batch_size': self.batch_size, |
| 73 | 'draft_num_tokens': self.draft_num_tokens if self.draft_num_tokens is not None else self.num_tokens, |
| 74 | 'enable_microbatch': int(self.enable_microbatch or False), |
| 75 | } |
| 76 | return [ |
| 77 | raw_values[name] |
| 78 | for name in self.field_names( |
| 79 | is_spec_enabled=is_spec_enabled, |
| 80 | is_microbatch_enabled=is_microbatch_enabled, |
| 81 | ) |
| 82 | ] |
| 83 | |
| 84 | |
| 85 | @dataclass |