Because the broadcast type computed for the Clip operator is BcastType::UNKNOWN_BCAST_TYPE, execution falls through to fallback::ElemwiseImpl::exec(srcs, dst); and since srcs.size() > 2, that in turn calls naive::ElemwiseForwardImpl::exec(srcs, dst);
First look at the call stack diagram:
The Elemwise implementation first defines ModeDispatcher, the struct that selects and launches the kernel for each mode. It is reached through the on_arity_dispatched_cb_dtype macro, which is expanded inside the methods ElemwiseForwardImpl::on_arity_dispatched and ElemwiseForwardImpl::on_arity_dispatched_no_bool.
#define on_arity_dispatched_cb_dtype(_dt) \ if (m_dst->layout.dtype == _dt()) { \ using dtrait = DTypeTrait<_dt>; \ using ctype = dtrait::ctype; \ return ModeDispatcher<arity, dtrait::category, ctype>::run( \ static_cast<HandleImpl*>(handle()), src, m_param.mode, *m_dst); \ } template <int arity> void ElemwiseForwardImpl::on_arity_dispatched() { printf("********************** on_arity_dispatched\\ "); auto src = make_elemwise_op_param<arity>(); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype) MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype) on_arity_dispatched_cb_dtype(::megdnn::dtype::Bool) megdnn_throw("bad dtype"); } template <int arity> void ElemwiseForwardImpl::on_arity_dispatched_no_bool() { printf("********************** on_arity_dispatched_no_bool\\ "); auto src = make_elemwise_op_param<arity>(); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype) MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype) megdnn_throw("bad dtype"); }
The following is the definition of ModeDispatcher:
#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT); #define IMPL_MODE_DISPATCHER(_arity, _dtype_cat) \ template <typename ctype>\ struct ElemwiseForwardImpl::ModeDispatcher<_arity, _dtype_cat, ctype> { \ static constexpr int arity = _arity; \ static void run(\ HandleImpl* handle, const ElemwiseOpParamN<arity> & amp; src, Mode mode, \ const TensorND dst) { \ switch (mode) { \ FOREACH(_cb_dispatch_mode) \ default: \ megdnn_throw("bad mode"); \ } \ } \ } #undef FOREACH #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb)\ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)\ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb)\ MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb)\ MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) #define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)
After expanding the code:
template <typename ctype> struct ElemwiseForwardImpl::ModeDispatcher<_arity, _dtype_cat, ctype> { static constexpr int arity = _arity; static void run( HandleImpl* handle, const ElemwiseOpParamN<arity> & amp; src, Mode mode, const TensorND dst) { switch (mode) { FOREACH(_cb_dispatch_mode) default: megdnn_throw("bad mode"); } } expands to: template <typename ctype> struct ElemwiseForwardImpl::ModeDispatcher<_arity, _dtype_cat, ctype> { static constexpr int arity = _arity; static void run( HandleImpl* handle, const ElemwiseOpParamN<arity> & amp; src, Mode mode, const TensorND dst) { switch (mode) { _cb_dispatch_mode(COND_LEQ_MOV) _cb_dispatch_mode(COND_LT_MOV) _cb_dispatch_mode(FUSE_MUL_ADD3) _cb_dispatch_mode (CLIP) _cb_dispatch_mode(PRELU_GRAD) default: megdnn_throw("bad mode"); } } #define _cb_dispatch_mode(_m) \ case Mode::_m: \ do { \ using KernImpl = ElemwiseKern<\ megcorePlatformCPU, param_enumv::Elemwise::Mode::_m, ctype>; \ MIDOUT_BEGIN(\ megdnn_naive_elemwise, \ midout_iv(param_enumv::Elemwise::Mode::_m)) { \ auto params = src; \ MEGDNN_DISPATCH_CPU_KERN(\ handle, ElemArithKernCaller<arity MEGDNN_COMMA KernImpl>::run( \ dst.ptr<ctype>(), params)); \ return; \ } \ MIDOUT_END(); \ } while (0);
Next, look at the definition of the CLIP mode kernel in ElemwiseKern:
struct ElemwiseKern; DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z)); //! define kernel for all ctypes #define DEF_KERN_ALL(_mode, _imp) \ DEF_KERN_INT(_mode, _imp); \ DEF_KERN_FLOAT(_mode, _imp); //! define kernel for all float types #define DEF_KERN_FLOAT(_mode, _imp) \ DEF_KERN(dt_float32, _mode, _imp); \ DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \ DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);) //! define kernel for all int types #define DEF_KERN_INT(_mode, _imp) \ DEF_KERN(dt_int32, _mode, _imp); \ DEF_KERN(dt_int16, _mode, _imp); \ DEF_KERN(dt_int8, _mode, _imp); \ DEF_KERN(dt_uint8, _mode, _imp); //! define kernel for a single ctype #define DEF_KERN(_ctype, _mode, _imp) \ template <megcorePlatform_t plat> \ struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> { \ typedef _ctype ctype; \ static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \ }
ElemArithKernCaller executes the code for the actual operator:
Here KernImpl is the ElemwiseKern specialization defined above; MEGDNN_DISPATCH_CPU_KERN wraps the call in a lambda and hands it to the handle's CPU dispatcher for execution.
/*! * \brief operator impls should utilize this method to * \param _handle a pointer to HandleImpl * \param _stmt the statements to be executed for the kernel */ #define MEGDNN_DISPATCH_CPU_KERN(_handle, _stmt) \ do { \ auto _kern = [=]() { _stmt; }; \ _handle->dispatch_kern(_kern); \ } while (0) template <int arity, class KernImpl> struct ElemArithKernCaller { typedef typename KernImpl::ctype ctype; static void run(ctype* dest, const ElemwiseOpParamN<arity> & amp; param); }; template <class KernImpl> struct ElemArithKernCaller<1, KernImpl> { typedef typename KernImpl::ctype ctype; static void run(ctype* dest, const ElemwiseOpParamN<1> & amp; param) { auto iter0 = tensor_iter_valonly<ctype>(param[0]).begin(); for (size_t i = 0; i < param.size; + + i) { dest[i] = KernImpl::apply(*iter0); + + iter0; } } }; template <class KernImpl> struct ElemArithKernCaller<2, KernImpl> { typedef typename KernImpl::ctype ctype; static void run(ctype* dest, const ElemwiseOpParamN<2> & amp; param) { auto iter0 = tensor_iter_valonly<ctype>(param[0]).begin(); auto iter1 = tensor_iter_valonly<ctype>(param[1]).begin(); for (size_t i = 0; i < param.size; + + i) { dest[i] = KernImpl::apply(*iter0, *iter1); + + iter0; + + iter1; } } }; template <class KernImpl> struct ElemArithKernCaller<3, KernImpl> { typedef typename KernImpl::ctype ctype; static void run(ctype* dest, const ElemwiseOpParamN<3> & amp; param) { auto iter0 = tensor_iter_valonly<ctype>(param[0]).begin(); auto iter1 = tensor_iter_valonly<ctype>(param[1]).begin(); auto iter2 = tensor_iter_valonly<ctype>(param[2]).begin(); for (size_t i = 0; i < param.size; + + i) { dest[i] = KernImpl::apply(*iter0, *iter1, *iter2); + + iter0; + + iter1; + + iter2; } } };
The material covered in this article matches the official reference sources, which you can consult to study the topic further.