/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file intrin_rule_opencl.cc
 * \brief OpenCL intrinsic rules.
 */
#include <tvm/arith/analyzer.h>

#include "../intrin_rule.h"

namespace tvm {
namespace codegen {
namespace intrin {

// Registrations of OpenCL intrinsic rules. Each statement below binds the global
// "tvm.intrin.rule.opencl.<name>" to DispatchPureExtern<Direct>.
// NOTE(review): DispatchPureExtern and Direct are declared in ../intrin_rule.h, which is
// not visible here; presumably they rewrite the intrinsic into a pure extern call to the
// OpenCL built-in of the same name -- confirm against that header.

// Rounding and absolute-value rules.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchPureExtern<Direct>);

// Exponential and logarithm rules.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchPureExtern<Direct>);

// Remaining rules: hyperbolic, root/power, popcount, remainder and trigonometric.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchPureExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchPureExtern<Direct>);

/*!
 * \brief Rewrite a tvm_warp_shuffle call into a call to intel_sub_group_shuffle.
 *
 * Standard OpenCL has no warp shuffle instruction, so whenever a shuffle is
 * used it is assumed that Intel's shuffle extension is available.
 *
 * \param args Packed arguments; args[0] is the call expression to rewrite. Its own
 *        arguments must be (mask, value, warp_id, width, warp_size).
 * \param rv Receives a pure extern call to intel_sub_group_shuffle(value, warp_id)
 *        carrying the dtype of the original call.
 */
static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
  PrimExpr e = args[0];
  const CallNode* call = e.as<CallNode>();
  CHECK(call != nullptr) << "tvm_warp_shuffle rule expects a call expression";
  CHECK_EQ(call->args.size(), 5)
      << "tvm_warp_shuffle expects (mask, value, warp_id, width, warp_size)";
  // The emitted call has no width operand, so only shuffles whose width is
  // provably equal to warp_size can be expressed.
  arith::Analyzer analyzer;
  CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
      << "Intel warp shuffle does not support width != warp_size";
  // Only value (args[1]) and warp_id (args[2]) are forwarded; the mask (args[0])
  // is not passed on to the emitted call.
  Array<PrimExpr> opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}};
  *rv = Call(call->dtype, builtin::call_pure_extern(), opencl_args);
}

// Bind the global "tvm.intrin.rule.opencl.tvm_warp_shuffle" to DispatchIntelShuffle, so
// that function serves as the rule registered under that name.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle);

}  // namespace intrin
}  // namespace codegen
}  // namespace tvm
