| 227 | |
| 228 | |
| 229 | class ModuleList(Module): |
def __init__(self, modules=None) -> None:
    """Create a list container holding the given sub-modules.

    Args:
        modules: optional iterable of sub-modules. Each element is stored in
            ``self._modules`` under its position as a string key
            ("0", "1", ...), continuing from the current length. ``None``
            (the default) creates an empty container instead of failing in
            ``enumerate(None)``.
    """
    super().__init__()
    if modules is None:
        return
    # Number new entries after whatever is already stored so keys stay contiguous.
    offset = len(self)
    for i, module in enumerate(modules):
        self._modules[str(offset + i)] = module
| 235 | |
def _get_abs_string_index(self, idx):
    """Map a possibly-negative position to the string key used in ``_modules``.

    Args:
        idx: any object accepted by ``operator.index``; negative values
            count back from the end of the container.

    Returns:
        The non-negative position rendered as a ``str``.

    Raises:
        IndexError: if ``idx`` lies outside ``[-len(self), len(self))``.
        TypeError: if ``idx`` is not integer-like (from ``operator.index``).
    """
    position = operator.index(idx)
    size = len(self)
    if position < -size or position >= size:
        raise IndexError("index {} is out of range".format(position))
    # Wrap negative positions around to their positive equivalent.
    return str(position + size if position < 0 else position)
| 244 | |
def __getitem__(self, idx):
    """Fetch sub-modules by position.

    Args:
        idx: an integer-like index (negative values count from the end)
            or a ``slice``.

    Returns:
        For a slice, a new ``self.__class__`` instance built from the
        selected entries of ``self._modules.values()``; otherwise the single
        sub-module stored at that position.

    Raises:
        IndexError: if an integer index is out of range.
    """
    if not isinstance(idx, slice):
        return self._modules[self._get_abs_string_index(idx)]
    picked = list(self._modules.values())[idx]
    return self.__class__(picked)
| 250 | |
def __setitem__(self, idx, module) -> None:
    """Replace the sub-module stored at position ``idx``.

    Args:
        idx: integer-like position; negative values count from the end.
        module: the object to store at that position.

    Raises:
        IndexError: if ``idx`` is out of range.
        TypeError: if ``idx`` is not integer-like.
    """
    # _get_abs_string_index already returns a str key, so no further conversion is needed.
    key = self._get_abs_string_index(idx)
    # Assign via setattr rather than writing self._modules directly; presumably
    # Module.__setattr__ performs registration here (TODO confirm against Module).
    setattr(self, key, module)
| 254 | |
def __len__(self):
    """Return how many sub-modules are held in ``self._modules``."""
    count = len(self._modules)
    return count
| 257 | |
| 258 | def __repr__(self): |
| 259 | """Return a custom repr for ModuleList that compresses repeated module representations.""" |
| 260 | list_of_reprs = [repr(item) for item in self] |
| 261 | if len(list_of_reprs) == 0: |
| 262 | return self._get_name() + "()" |
| 263 | |
| 264 | start_end_indices = [[0, 0]] |
| 265 | repeated_blocks = [list_of_reprs[0]] |
| 266 | for i, r in enumerate(list_of_reprs[1:], 1): |
| 267 | if r == repeated_blocks[-1]: |
| 268 | start_end_indices[-1][1] += 1 |
| 269 | continue |
| 270 | |
| 271 | start_end_indices.append([i, i]) |
| 272 | repeated_blocks.append(r) |
| 273 | |
| 274 | lines = [] |
| 275 | main_str = self._get_name() + "(" |
| 276 | for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): |
| 277 | local_repr = f"({start_id}): {b}" # default repr |
| 278 | |
| 279 | if start_id != end_id: |
| 280 | n = end_id - start_id + 1 |
| 281 | local_repr = f"({start_id}-{end_id}): {n} x {b}" |
| 282 | |
| 283 | local_repr = _addindent(local_repr, 2) |
| 284 | lines.append(local_repr) |
| 285 | |
| 286 | main_str += "\n " + "\n ".join(lines) + "\n" |
no outgoing calls