Load the kernel according to the current machine. Args: ext_name (str): the name of the extension to be loaded. If not specified, the loader will try to look for a kernel available on the current machine.
(self, ext_name: str = None)
| 51 | cls.REGISTRY.append(extension) |
| 52 | |
def load(self, ext_name: str = None):
    """
    Load the kernel according to the current machine.

    Args:
        ext_name (str): the name of the extension to be loaded. If not specified, the loader
            will try to look for a kernel available on the current machine.

    Returns:
        Whatever the selected extension's ``load()`` returns.

    Raises:
        AssertionError: if ``ext_name`` is given and no registered extension has that name,
            or if ``ext_name`` is not given and no registered extension reports itself
            available.
    """
    # instantiate every extension class registered on this loader class
    exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]

    # look for exts which can be built/loaded on the current machine
    if ext_name:
        # an explicitly requested extension is selected by name only;
        # is_available()/assert_compatible() are not consulted in this branch
        usable_exts = [ext for ext in exts if ext.name == ext_name]
    else:
        usable_exts = []
        for ext in exts:
            if ext.is_available():
                # make sure the machine is compatible during kernel loading
                ext.assert_compatible()
                usable_exts.append(ext)

    if not usable_exts:
        # Raised explicitly instead of via ``assert`` so the check survives ``python -O``
        # (otherwise usable_exts[0] below would fail with an opaque IndexError).
        # AssertionError is kept so existing callers that catch it keep working.
        if ext_name:
            msg = f"No kernel named '{ext_name}' is registered for {self.__class__.__name__}."
        else:
            msg = f"No usable kernel found for {self.__class__.__name__} on the current machine."
        raise AssertionError(msg)

    if len(usable_exts) > 1:
        # if more than one usable kernel is found, we will try to load the kernel with the highest priority
        # (sorted() is stable, so equal priorities keep registration order)
        usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
        warnings.warn(
            f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}",
            stacklevel=2,
        )
    return usable_exts[0].load()
| 84 | |
| 85 | |
| 86 | class CPUAdamLoader(KernelLoader): |
no test coverage detected