6.1。高级扩展 API

原文: http://numba.pydata.org/numba-doc/latest/extending/high-level.html

此扩展 API 通过 numba.extending 模块公开。

6.1.1。实现功能

@overload装饰器允许您实现在 nopython 模式功能中使用的任意函数。用@overload修饰的函数在编译时使用函数运行时参数的 类型 调用。它应该返回一个 callable,表示给定类型的函数的 实现 。返回的实现由 Numba 编译,就像它是用@jit修饰的普通函数一样。 @jit的其他选项可以使用jit_options参数作为字典传递。

例如,让我们假装 Numba 不支持元组上的 len() 函数。以下是使用@overload实现它的方法:

  1. from numba import types
  2. from numba.extending import overload
  3. @overload(len)
  4. def tuple_len(seq):
  5. if isinstance(seq, types.BaseTuple):
  6. n = len(seq)
  7. def len_impl(seq):
  8. return n
  9. return len_impl

您可能想知道,如果使用除元组以外的其他内容调用 len() 会发生什么?如果用@overload修饰的函数不返回任何内容(即返回 None),则尝试其他定义直到成功。因此,多个库可能会使 len() 超载不同类型,而不会相互冲突。

6.1.2。实施方法

@overload_method装饰器类似地允许在 Numba 众所周知的类型上实现方法。以下示例在 Numpy 数组上实现 take() 方法:

  1. @overload_method(types.Array, 'take')
  2. def array_take(arr, indices):
  3. if isinstance(indices, types.Array):
  4. def take_impl(arr, indices):
  5. n = indices.shape[0]
  6. res = np.empty(n, arr.dtype)
  7. for i in range(n):
  8. res[i] = arr[indices[i]]
  9. return res
  10. return take_impl

6.1.3。实现属性

@overload_attribute装饰器允许在类型上实现数据属性(或属性)。只能读取属性;只有低级 API 支持可写属性。

以下示例在 Numpy 数组上实现 nbytes 属性:

  1. @overload_attribute(types.Array, 'nbytes')
  2. def array_nbytes(arr):
  3. def get(arr):
  4. return arr.size * arr.itemsize
  5. return get

6.1.4。导入 Cython 函数

函数get_cython_function_address获取 Cython 扩展模块中 C 函数的地址。该地址可用于通过 ctypes.CFUNCTYPE() 回调访问 C 函数,从而允许在 Numba jitted 函数中使用 C 函数。例如,假设您有文件foo.pyx

  1. from libc.math cimport exp
  2. cdef api double myexp(double x):
  3. return exp(x)

您可以通过以下方式从 Numba 访问myexp

  1. import ctypes
  2. from numba.extending import get_cython_function_address
  3. addr = get_cython_function_address("foo", "myexp")
  4. functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
  5. myexp = functype(addr)

函数myexp现在可以在 jitted 函数中使用,例如:

  1. @njit
  2. def double_myexp(x):
  3. return 2*myexp(x)

需要注意的是,如果您的函数使用 Cython 的融合类型,那么函数的名称将被破坏。要找出函数的错位名称,可以检查扩展模块的__pyx_capi__属性。