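Get the buffer of a specific item. Args: item (str): The demanded item. ids (list[int]): The demanded ids. num_samples (int, optional): Number of samples to calculate the results. Defaults to None. behavior (str, optional): Behavior to calculate the results. Options are `mean` | None. Defaults to None. Returns: Tensor: The results of the demanded item.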
(self, item, ids=None, num_samples=None, behavior=None)
| 143 | return outs |
| 144 | |
def get(self, item, ids=None, num_samples=None, behavior=None):
    """Get the buffer of a specific item.

    For every requested id, ``self.tracks[id][item]`` is looked up and
    reduced to a single tensor; the per-track results are then
    concatenated along dim 0 in the order given by ``ids``.

    Args:
        item (str): The demanded item.
        ids (list[int], optional): The demanded ids. If None,
            ``self.ids`` is used. Defaults to None.
        num_samples (int, optional): Number of most recent samples to
            calculate the results. Only consulted when the stored value
            is a list; if None, the latest entry of that list is
            returned as is. Defaults to None.
        behavior (str, optional): Behavior to calculate the results.
            Only consulted when the stored value is a list and
            ``num_samples`` is not None. Options are `mean` | None.
            `mean` averages the selected samples over dim 0 (keeping
            that dim), None prepends a new leading dim to the
            concatenated samples. Defaults to None.

    Returns:
        Tensor: The results of the demanded item.

    Raises:
        NotImplementedError: If ``behavior`` is neither `mean` nor None
            while samples of a list-valued item are being aggregated.
    """
    if ids is None:
        ids = self.ids

    outs = []
    for track_id in ids:
        out = self.tracks[track_id][item]
        if isinstance(out, list):
            if num_samples is None:
                # No aggregation requested: use the most recent entry only.
                out = out[-1]
            else:
                # Keep the trailing ``num_samples`` entries. Note that
                # ``num_samples=0`` selects the whole history, because
                # ``out[-0:]`` is the same slice as ``out[0:]``.
                out = torch.cat(out[-num_samples:], dim=0)
                if behavior == 'mean':
                    out = out.mean(dim=0, keepdim=True)
                elif behavior is None:
                    out = out[None]
                else:
                    raise NotImplementedError(
                        f'Unsupported behavior: {behavior!r}. '
                        'Options are `mean` | None.')
        # Values that are not lists are appended exactly as stored.
        outs.append(out)
    return torch.cat(outs, dim=0)
| 179 | |
| 180 | @abstractmethod |
| 181 | def track(self, *args, **kwargs): |
no outgoing calls