Execution of Elemwise operators in ONNX (using CLIP as an example)

The broadcast type (`broad_cast_type`) of the Clip operator is `BcastType::UNKNOWN_BCAST_TYPE`, so execution is handled by

fallback::ElemwiseImpl::exec(srcs, dst);

and then, because srcs.size() > 2, that function calls:
naive::ElemwiseForwardImpl::exec(srcs, dst);

First, consider the call stack (the original call-stack diagram is not reproduced here):

`ElemwiseForwardImpl` defines the operator dispatcher `ModeDispatcher`. The dispatcher is invoked through the `on_arity_dispatched_cb_dtype` macro, which is expanded inside `ElemwiseForwardImpl::on_arity_dispatched` and `ElemwiseForwardImpl::on_arity_dispatched_no_bool`.

// Tries a single dtype _dt: if the destination tensor's dtype equals _dt(),
// it runs ModeDispatcher<arity, DTypeTrait<_dt>::category, DTypeTrait<_dt>::ctype>::run
// on this handle and returns from the enclosing function.
// Expects `src`, `arity`, `m_dst`, `m_param` and `handle()` to be in scope at
// the expansion site (as they are inside ElemwiseForwardImpl::on_arity_dispatched*).
#define on_arity_dispatched_cb_dtype(_dt) \
    if (m_dst->layout.dtype == _dt()) { \
        using dtrait = DTypeTrait<_dt>; \
        using ctype = dtrait::ctype; \
        return ModeDispatcher<arity, dtrait::category, ctype>::run( \
                static_cast<HandleImpl*>(handle()), src, m_param.mode, *m_dst); \
    }



template <int arity>
void ElemwiseForwardImpl::on_arity_dispatched() {
    printf("********************** on_arity_dispatched\\
");
    auto src = make_elemwise_op_param<arity>();
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype)
    MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype)
    on_arity_dispatched_cb_dtype(::megdnn::dtype::Bool) megdnn_throw("bad dtype");
}

template <int arity>
void ElemwiseForwardImpl::on_arity_dispatched_no_bool() {
    printf("********************** on_arity_dispatched_no_bool\\
");
    auto src = make_elemwise_op_param<arity>();
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype)
    MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype)
    megdnn_throw("bad dtype");
}

The following is the definition of ModeDispatcher:

#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT
IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT);
#define IMPL_MODE_DISPATCHER(_arity, _dtype_cat) \
    template <typename ctype>\
    struct ElemwiseForwardImpl::ModeDispatcher<_arity, _dtype_cat, ctype> { \
        static constexpr int arity = _arity; \
        static void run(\
                HandleImpl* handle, const ElemwiseOpParamN<arity> & amp; src, Mode mode, \
                const TensorND dst) { \
            switch (mode) { \
                FOREACH(_cb_dispatch_mode) \
                default: \
                    megdnn_throw("bad mode"); \
            } \
        } \
    }

#undef FOREACH

// X-macro listing the ternary (3-input) elemwise modes dispatched for the
// FLOAT dtype category. Each mode name is passed to cb through
// MEGDNN_ELEMWISE_MODE_ENABLE.
#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb)\
    MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)\
    MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb)\
    MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)\
    MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb)

// Applies _cb to _mode unconditionally. The indirection presumably lets other
// build configurations redefine it to drop modes — TODO confirm.
#define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)

Expanding the `IMPL_MODE_DISPATCHER` macro gives:


// Result of expanding IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT):
// _arity -> 3 and _dtype_cat -> DTypeCategory::FLOAT. FOREACH is not yet expanded.
template <typename ctype>
struct ElemwiseForwardImpl::ModeDispatcher<3, DTypeCategory::FLOAT, ctype> {
    static constexpr int arity = 3;
    static void run(
            HandleImpl* handle, const ElemwiseOpParamN<arity>& src, Mode mode,
            const TensorND dst) {
        switch (mode) {
            FOREACH(_cb_dispatch_mode)
            default:
                megdnn_throw("bad mode");
        }
    }
};

Expanding `FOREACH(_cb_dispatch_mode)` further gives:

// Result of additionally expanding FOREACH(_cb_dispatch_mode) with the
// ternary-float mode list: one _cb_dispatch_mode(...) per mode, each of which
// becomes a `case Mode::<mode>:` label in the switch.
template <typename ctype>
struct ElemwiseForwardImpl::ModeDispatcher<3, DTypeCategory::FLOAT, ctype> {
    static constexpr int arity = 3;
    static void run(
            HandleImpl* handle, const ElemwiseOpParamN<arity>& src, Mode mode,
            const TensorND dst) {
        switch (mode) {
            _cb_dispatch_mode(COND_LEQ_MOV)
            _cb_dispatch_mode(COND_LT_MOV)
            _cb_dispatch_mode(FUSE_MUL_ADD3)
            _cb_dispatch_mode(CLIP)
            _cb_dispatch_mode(PRELU_GRAD)
            default:
                megdnn_throw("bad mode");
        }
    }
};

// Generates one `case Mode::_m:` of the switch in ModeDispatcher::run.
// It expects `ctype`, `arity`, `handle`, `src` and `dst` to be in scope at the
// expansion site (ModeDispatcher::run provides them).
// Steps performed:
//   1. KernImpl = ElemwiseKern<megcorePlatformCPU, _m, ctype>.
//   2. `params` is a copy of `src`, so the dispatched lambda (which captures
//      by copy in MEGDNN_DISPATCH_CPU_KERN) owns its own copy.
//   3. ElemArithKernCaller<arity, KernImpl>::run is handed to the CPU handle.
//   4. Returns from run().
// MEGDNN_COMMA presumably expands to ',' so the template argument list is not
// split into two macro arguments. MIDOUT_BEGIN/MIDOUT_END wrap the case;
// presumably this is for midout-based binary trimming — TODO confirm.
#define _cb_dispatch_mode(_m) \
    case Mode::_m: \
        do { \
            using KernImpl = ElemwiseKern<\
                    megcorePlatformCPU, param_enumv::Elemwise::Mode::_m, ctype>; \
            MIDOUT_BEGIN(\
                    megdnn_naive_elemwise, \
                    midout_iv(param_enumv::Elemwise::Mode::_m)) { \
                auto params = src; \
                MEGDNN_DISPATCH_CPU_KERN(\
                        handle, ElemArithKernCaller<arity MEGDNN_COMMA KernImpl>::run( \
                                        dst.ptr<ctype>(), params)); \
                return; \
            } \
            MIDOUT_END(); \
        } while (0);

Next, look at how `ElemwiseKern` is defined for the CLIP mode:

struct ElemwiseKern;



//! CLIP(x, y, z) returns:
//!   - y if x <= y,
//!   - otherwise x if x <= z,
//!   - otherwise z.
//! That is, x is clamped into [y, z] when y <= z. The macro defines one
//! ElemwiseKern specialization per int and float ctype.
//! NOTE(review): in this excerpt DEF_KERN_ALL is defined further down; in
//! compilable code its definition must precede this line.
DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z));

//! define kernel for all ctypes: expands to the int variants
//! (DEF_KERN_INT) followed by the float variants (DEF_KERN_FLOAT) of the
//! ElemwiseKern specialization for _mode, each computing _imp.
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp); \
    DEF_KERN_FLOAT(_mode, _imp);

//! define kernel for all float types: dt_float32 always, plus dt_float16 and
//! dt_bfloat16 wrapped in DNN_INC_FLOAT16 (presumably compiled out when
//! half-precision support is disabled — TODO confirm).
#define DEF_KERN_FLOAT(_mode, _imp) \
    DEF_KERN(dt_float32, _mode, _imp); \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)

//! define kernel for all int types: dt_int32, dt_int16, dt_int8 and dt_uint8.
#define DEF_KERN_INT(_mode, _imp) \
    DEF_KERN(dt_int32, _mode, _imp); \
    DEF_KERN(dt_int16, _mode, _imp); \
    DEF_KERN(dt_int8, _mode, _imp); \
    DEF_KERN(dt_uint8, _mode, _imp);


//! define kernel for a single ctype: a partial specialization of
//! ElemwiseKern<plat, _mode, _ctype> for every platform `plat`. Its static
//! apply() evaluates _imp and converts the result to ctype.
//! KERN_SIG is defined elsewhere; presumably it declares the operand
//! parameters (x, y, z, ...) that _imp refers to — TODO confirm.
//! __host__ __device__ presumably lets the same kernel be compiled for both
//! CPU and CUDA.
#define DEF_KERN(_ctype, _mode, _imp) \
    template <megcorePlatform_t plat> \
    struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> { \
        typedef _ctype ctype; \
        static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \
    }

`ElemArithKernCaller` runs the actual element-wise loop of the operator:

Here, `KernImpl` is the `ElemwiseKern` specialization defined above, and the call is scheduled for execution on the CPU through `MEGDNN_DISPATCH_CPU_KERN`.

/*!
 * \brief operator impls should utilize this method to hand a kernel to the
 *      CPU handle for execution
 * \param _handle a pointer to HandleImpl
 * \param _stmt the statements to be executed for the kernel
 *
 * _stmt is wrapped in a lambda that captures by copy ([=]) and is passed to
 * _handle->dispatch_kern(). Whether dispatch_kern runs it immediately or
 * queues it is not visible here (TODO confirm). If it is queued, anything
 * reached through captured pointers must outlive the call.
 */
#define MEGDNN_DISPATCH_CPU_KERN(_handle, _stmt) \
    do { \
        auto _kern = [=]() { _stmt; }; \
        _handle->dispatch_kern(_kern); \
    } while (0)


//! Applies KernImpl::apply element-wise over the `arity` inputs in param and
//! writes param.size results to dest. The primary template is only declared;
//! the per-arity partial specializations provide run().
template <int arity, class KernImpl>
struct ElemArithKernCaller {
    typedef typename KernImpl::ctype ctype;
    static void run(ctype* dest, const ElemwiseOpParamN<arity>& param);
};


template <class KernImpl>
struct ElemArithKernCaller<1, KernImpl> {
    typedef typename KernImpl::ctype ctype;
    static void run(ctype* dest, const ElemwiseOpParamN<1> & amp; param) {
        auto iter0 = tensor_iter_valonly<ctype>(param[0]).begin();
        for (size_t i = 0; i < param.size; + + i) {
            dest[i] = KernImpl::apply(*iter0);
             + + iter0;
        }
    }
};
template <class KernImpl>
struct ElemArithKernCaller<2, KernImpl> {
    typedef typename KernImpl::ctype ctype;
    static void run(ctype* dest, const ElemwiseOpParamN<2> & amp; param) {
        auto iter0 = tensor_iter_valonly<ctype>(param[0]).begin();
        auto iter1 = tensor_iter_valonly<ctype>(param[1]).begin();
        for (size_t i = 0; i < param.size; + + i) {
            dest[i] = KernImpl::apply(*iter0, *iter1);
             + + iter0;
             + + iter1;
        }
    }
};
template <class KernImpl>
struct ElemArithKernCaller<3, KernImpl> {
    typedef typename KernImpl::ctype ctype;
    static void run(ctype* dest, const ElemwiseOpParamN<3> & amp; param) {
        auto iter0 = tensor_iter_valonly<ctype>(param[0]).begin();
        auto iter1 = tensor_iter_valonly<ctype>(param[1]).begin();
        auto iter2 = tensor_iter_valonly<ctype>(param[2]).begin();
        for (size_t i = 0; i < param.size; + + i) {
            dest[i] = KernImpl::apply(*iter0, *iter1, *iter2);
             + + iter0;
             + + iter1;
             + + iter2;
        }
    }
};

The content of this article is consistent with the official documentation, which you can consult to learn more.