CUDAでFunctorを使う
CUDAの.cuファイルはC++として解釈されるので、templateといった記法やfunctorといったデザインをデバイスコード上でも使うことができます。nvcc 3.0で動作を確認しました。もう2.xの環境が手元になかったので、2.xで使えるかはわかりません。
struct add { template <class T> __device__ T operator ()(T const a, T const b) const { return a + b; } }; struct mul { template <class T> __device__ T operator ()(T const a, T const b) const { return a * b; } }; template <class F> __global__ void element_wise_operation(double* out, double const* a, double const* b, F f) { size_t const i = blockIdx.x * blockDim.x + threadIdx.x; out[i] = f(a[i], b[i]); }
サンプルとして、このようなfunctor classと、それを利用するコードを書きました。を定義しました。operator ()を__device__修飾することにより、__device__や__global__のコードから呼ぶことができるようになります。
呼び出す時は、このようにホスト側でfunctorを生成して渡すことができます。
element_wise_operation<<<1, 10>>>(out, a, b, add()); element_wise_operation<<<1, 10>>>(out, a, b, mul());
また、このように書くことで、パフォーマンスが心配になりますが、functorの呼び出しはインライン展開されるため、普通に書き下した場合と同等のPTXコードが生成されるようです。