CUDAでFunctorを使う

CUDAの.cuファイルはC++として解釈されるので、templateといった記法やfunctorといったデザインをデバイスコード上でも使うことができます。nvcc 3.0で動作を確認しました。もう2.xの環境が手元になかったので、2.xで使えるかはわかりません。

struct add {
    template <class T>
    __device__
    T operator ()(T const a, T const b) const {
        return a + b;
    }
};

struct mul {
    template <class T>
    __device__
    T operator ()(T const a, T const b) const { 
        return a * b;
    }
};

template <class F>
__global__
void element_wise_operation(double* out, double const* a, double const* b, F f) {
    size_t const i = blockIdx.x * blockDim.x + threadIdx.x;
    out[i] = f(a[i], b[i]);
}

サンプルとして、このようなfunctor classと、それを利用するコードを書きました。を定義しました。operator ()を__device__修飾することにより、__device__や__global__のコードから呼ぶことができるようになります。

呼び出す時は、このようにホスト側でfunctorを生成して渡すことができます。

element_wise_operation<<<1, 10>>>(out, a, b, add());
element_wise_operation<<<1, 10>>>(out, a, b, mul());

また、このように書くことで、パフォーマンスが心配になりますが、functorの呼び出しはインライン展開されるため、普通に書き下した場合と同等のPTXコードが生成されるようです。