(...args)
| 18026 | return fn; |
| 18027 | } |
| 18028 | define(...args) { |
| 18029 | if (Array.isArray(args[1])) { |
| 18030 | const [prefix, properties, propResolvers, definitions, defResolvers, self, shouldMangle, operator_set_version] = args; |
| 18031 | torch._C.TORCH_INTERNAL_ASSERT(definitions.length === defResolvers.length); |
| 18032 | torch._C.TORCH_INTERNAL_ASSERT(properties.length === propResolvers.length); |
| 18033 | const functions = []; |
| 18034 | const function_table = new Map(); |
| 18035 | const record_function = (fn) => { |
| 18036 | function_table.set(fn.name(), fn); |
| 18037 | functions.push(fn); |
| 18038 | this.register_function(fn); |
| 18039 | }; |
| 18040 | for (let i = 0; i < properties.length; i++) { |
| 18041 | const property_fns = this.define_property(prefix, properties[i], propResolvers[i], self, function_table, shouldMangle); |
| 18042 | const getter_fn = property_fns.getGetter(); |
| 18043 | const setter_fn = property_fns.getSetter(); |
| 18044 | record_function(getter_fn); |
| 18045 | if (setter_fn) { |
| 18046 | record_function(setter_fn); |
| 18047 | } |
| 18048 | } |
| 18049 | for (let i = 0; i < definitions.length; i++) { |
| 18050 | const fn = this.define(prefix, definitions[i], defResolvers[i], self, function_table, shouldMangle, 'Method', operator_set_version); |
| 18051 | record_function(fn); |
| 18052 | } |
| 18053 | for (const [name, fn] of function_table) { |
| 18054 | if (name === '__init__') { |
| 18055 | fn.ensure_defined(); |
| 18056 | } |
| 18057 | } |
| 18058 | for (const fn of functions) { |
| 18059 | fn.ensure_defined(); |
| 18060 | } |
| 18061 | return functions; |
| 18062 | } else if (args[1] instanceof ast.FunctionDef) { |
| 18063 | const [prefix, def, resolver, self, function_table, shouldMangle, type, operator_set_version] = args; |
| 18064 | const _resolver = self ? resolver : new torch._C.FunctionResolver(resolver, function_table); |
| 18065 | const creator = (method) => { |
| 18066 | return new torch._C.to_ir(def, _resolver, self, method); |
| 18067 | }; |
| 18068 | let name = prefix ? new torch._C.QualifiedName(prefix, def.name) : new torch._C.QualifiedName(def.name); |
| 18069 | if (shouldMangle && this.find_function(name)) { |
| 18070 | name = this.mangle(name); |
| 18071 | } |
| 18072 | const graph = new torch.Graph(); |
| 18073 | graph.set_op_version(operator_set_version); |
| 18074 | const fn = new torch._C.GraphFunction(name, graph, creator); |
| 18075 | fn.__ast__ = def; // remove |
| 18076 | if (self) { |
| 18077 | if (type === 'hook') { |
| 18078 | self.getClassType().addForwardHook(fn); |
| 18079 | } else if (type === 'prehook') { |
| 18080 | self.getClassType().addPreHook(fn); |
| 18081 | } else { |
| 18082 | self.getClassType().addMethod(fn); |
| 18083 | } |
| 18084 | } |
| 18085 | return fn; |
no test coverage detected