(self, mode = "numpy", param_dict = None)
| 35 | f.close() |
| 36 | |
def get_parameters(self, mode="numpy", param_dict=None):
    """Return selected entries of ``self.state_dict()``, converted according to ``mode``.

    Args:
        mode: Output format for each entry.
            ``"numpy"`` converts each tensor with ``.detach().cpu().numpy()``.
            ``"list"`` converts each tensor to nested Python lists / scalars.
            Any other value returns the state_dict tensor unchanged.
        param_dict: Iterable of state_dict key names to include. Despite the
            name, it is only iterated for key names (a mapping works too, since
            iterating it yields its keys). ``None`` selects every key.

    Returns:
        dict mapping each selected key name to its converted value, in the
        iteration order of ``param_dict`` (or of the state_dict when ``None``).

    Raises:
        KeyError: if a requested name is not a key of ``self.state_dict()``.
    """
    state = self.state_dict()
    # `is None` rather than `== None`: comparing e.g. a numpy array of names
    # with `==` broadcasts elementwise and makes the `if` ambiguous.
    names = state.keys() if param_dict is None else param_dict

    def _convert(tensor):
        # Convert a single state_dict tensor according to `mode`.
        if mode == "numpy":
            # detach() is a no-op for default state_dict() tensors but avoids
            # the numpy() error on tensors that require grad. Note: for CPU
            # tensors, numpy() shares memory with the tensor, so writes to the
            # returned array are visible in the tensor.
            return tensor.detach().cpu().numpy()
        if mode == "list":
            # Tensor.tolist() yields the same Python values as going through
            # numpy, and also handles dtypes numpy lacks (e.g. bfloat16).
            return tensor.detach().cpu().tolist()
        return tensor

    return {name: _convert(state[name]) for name in names}
| 50 | |
| 51 | def set_parameters(self, parameters): |
| 52 | for i in parameters: |
no outgoing calls
no test coverage detected