Get summary statistics for a specific operation type. Args: operation_name (str): name of the communication operation. Returns: dict: summary statistics for the operation, or None if the operation is not found.
(self, operation_name)
| 169 | return total |
| 170 | |
def get_operation_summary(self, operation_name):
    """
    Get summary statistics for a specific operation type.

    Args:
        operation_name (str): Name of the communication operation

    Returns:
        dict: Summary statistics for the operation, or None if operation not
        found. The dict is keyed by message size; each value is a dict with
        keys "count", "total_latency_ms", "avg_latency_ms", "tput_avg_gbps",
        "busbw_avg_gbps", "msg_size_bytes" and "msg_size_str".
    """
    # Single lookup instead of `in` followed by indexing, so the key cannot
    # disappear between the membership test and the access.
    op_records = self.comms_dict.get(operation_name)
    if op_records is None:
        return None

    from deepspeed.utils.timer import trim_mean

    trim_fraction = 0.1

    # Create a snapshot to avoid concurrent modification issues.
    # dict.copy() is shallow, so the per-size sample lists are copied too
    # (below). Without that, trim_mean -- which in deepspeed sorts its
    # input in place; re-check if that helper changes -- would reorder the
    # logger's stored lists as a side effect of this read-only query.
    op_data = op_records.copy()
    summary = {}

    for msg_size, vals in op_data.items():
        count = vals[0]
        latencies = list(vals[1])
        algbws = list(vals[2])
        busbws = list(vals[3])

        summary[msg_size] = {
            "count": count,
            "total_latency_ms": sum(latencies),
            "avg_latency_ms": trim_mean(latencies, trim_fraction),
            "tput_avg_gbps": trim_mean(algbws, trim_fraction),
            "busbw_avg_gbps": trim_mean(busbws, trim_fraction),
            "msg_size_bytes": msg_size,
            "msg_size_str": convert_size(msg_size),
        }

    return summary
| 208 | |
| 209 | # Print summary at end of iteration, epoch, or training |
| 210 | def log_all(self, print_log=True, show_straggler=False, return_dict=False): |
nothing calls this directly
no test coverage detected