/*************************************************************************
 * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#include "strongstream.h"
#include "rocmwrap.h"
#include "checks.h"
#include "param.h"

// Aliases that map the CUDA 13 "_v2"/"_v3" stream-capture entry points used in
// the CUDART_VERSION >= 13000 branches below onto their HIP counterparts.
// NOTE(review): presumably CUDART_VERSION is not defined in HIP builds, which
// would make these aliases (and every CUDART_VERSION-guarded branch in this
// file) dead code — confirm against the build configuration. Note also that
// some guarded branches call hipStreamGetCaptureInfo_v3, which is not aliased here.
#if CUDART_VERSION >= 13000
#define cudaStreamGetCaptureInfo_v3 hipStreamGetCaptureInfo
#define cudaGraphAddDependencies_v2 hipGraphAddDependencies
#define cudaStreamUpdateCaptureDependencies_v2 hipStreamUpdateCaptureDependencies
#endif

// Tracks the work captured for one graph, identified by its graph id.
// Entries form a singly linked list rooted at ncclStrongStream::captureHead.
struct ncclStrongStreamCapture {
  struct ncclStrongStreamCapture* next;  // Next entry in the owning strong stream's capture list.
  hipGraph_t graph;                      // NOTE(review): not read or written by the code visible here.
  unsigned long long graphId;            // Id of the capturing graph this entry serves.
  hipStream_t captureStream;             // Private stream joined to that graph's capture.
  void* acquiredBy;                      // localThreadId() of the last acquirer; checked on release.
};

////////////////////////////////////////////////////////////////////////////////

// Global singly linked list of tracked HIP contexts (one record per hipCtx_t),
// reference-counted by ncclCudaContextTrack / ncclCudaContextDrop.
// In this file every access to the list and to a record's refCount happens
// while holding cxtListLock.
static ncclCudaContext* cxtListHead = nullptr;
static pthread_mutex_t cxtListLock = PTHREAD_MUTEX_INITIALIZER;

// Returns (via *out) the tracking record for the calling thread's current HIP
// context. The record is created on first use, with its launchOrder strong
// stream constructed, and its reference count is bumped on every later call.
// Pair each successful call with ncclCudaContextDrop(), which decrements the count.
// On failure *out is set to nullptr and the error is returned; a half-built
// record is never left in the global list.
ncclResult_t ncclCudaContextTrack(struct ncclCudaContext** out) {
  ncclResult_t result = ncclSuccess;
  hipCtx_t hcontext;
  // Without a valid context handle the lookup key below would be garbage.
  CUDACHECK(hipCtxGetCurrent(&hcontext));

  pthread_mutex_lock(&cxtListLock);
  // Reuse an existing record for this context if there is one.
  struct ncclCudaContext* p = cxtListHead;
  while (p != nullptr && p->hcontext != hcontext) p = p->next;

  if (p != nullptr) {
    p->refCount += 1;
  } else {
    p = (struct ncclCudaContext*)calloc(1, sizeof(struct ncclCudaContext));
    if (p == nullptr) {
      WARN("Failed to allocate ncclCudaContext");
      result = ncclSystemError;
    } else {
      result = ncclStrongStreamConstruct(&p->launchOrder);
      if (result != ncclSuccess) {
        // Do not publish a record whose strong stream failed to construct.
        free(p);
        p = nullptr;
      } else {
        p->refCount = 1;
        p->hcontext = hcontext;
        p->next = cxtListHead;
        cxtListHead = p;
      }
    }
  }
  pthread_mutex_unlock(&cxtListLock);
  *out = p;
  return result;
}

// Releases one reference to `cxt`. When the last reference goes away the
// record is unlinked from the global list, its launchOrder strong stream is
// destroyed (errors ignored), and the record is freed. Serialized by cxtListLock.
void ncclCudaContextDrop(struct ncclCudaContext* cxt) {
  pthread_mutex_lock(&cxtListLock);
  cxt->refCount -= 1;
  if (cxt->refCount == 0) {
    // Find the link that points at cxt, then splice cxt out of the list.
    struct ncclCudaContext** link = &cxtListHead;
    for (; *link != cxt; link = &(*link)->next) {
    }
    *link = cxt->next;
    (void)ncclStrongStreamDestruct(&cxt->launchOrder);
    free(cxt);
  }
  pthread_mutex_unlock(&cxtListLock);
}

////////////////////////////////////////////////////////////////////////////////

// Fills *graph with the capture state of `stream`.
// If the stream is actively capturing, graph->graph and graph->graphId come from
// the runtime and graph->origin is set to `stream`. Otherwise all three fields are
// reset to the "not capturing" sentinel (nullptr / nullptr / ULLONG_MAX).
// Built against ROCm < 6.1, *graph is left untouched.
ncclResult_t ncclCudaGetCapturingGraph(
    struct ncclCudaGraph* graph, hipStream_t stream
  ) {
#if ROCM_VERSION >= 60100
  hipStreamCaptureStatus captureStatus;
  CUDACHECK(hipStreamGetCaptureInfo_v2(stream, &captureStatus, &graph->graphId, &graph->graph, nullptr, nullptr));
  const bool capturing = (captureStatus == hipStreamCaptureStatusActive);
  if (capturing) {
    graph->origin = stream;
  } else {
    graph->origin = nullptr;
    graph->graph = nullptr;
    graph->graphId = ULLONG_MAX;
  }
#endif
  return ncclSuccess;
}

// Arranges for fn(arg) to run when graph.graph releases it: fn/arg are wrapped
// in a HIP user object whose single reference is moved into the graph.
// Returns ncclInvalidUsage when built against ROCm < 6.1.
ncclResult_t ncclCudaGraphAddDestructor(struct ncclCudaGraph graph, hipHostFn_t fn, void* arg) {
  #if ROCM_VERSION >= 60100
    const unsigned int initialRefs = 1;
    hipUserObject_t destructorObj;
    CUDACHECK(hipUserObjectCreate(
      &destructorObj, arg, fn, initialRefs, hipUserObjectNoDestructorSync
    ));
    // Move our only reference into the graph; the graph now owns the object.
    // NOTE(review): if this call fails, the reference created above is never released.
    CUDACHECK(hipGraphRetainUserObject(graph.graph, destructorObj, initialRefs, hipGraphUserObjectMove));
    return ncclSuccess;
  #else
    return ncclInvalidUsage;
  #endif
}

////////////////////////////////////////////////////////////////////////////////

// Initializes a strong stream: creates its non-blocking live stream and, on
// ROCm >= 6.1, the capture bookkeeping (empty capture list, lock, and the
// timing-disabled serialEvent used to order captured vs. uncaptured work).
// If a later step fails, the steps already done are undone before the error is
// returned, so a failed construct leaks nothing.
ncclResult_t ncclStrongStreamConstruct(struct ncclStrongStream* ss) {
  CUDACHECK(hipStreamCreateWithFlags(&ss->liveStream, hipStreamNonBlocking));
  #if ROCM_VERSION >= 60100
    ss->everCaptured = false;
    ss->captureHead = nullptr;
    pthread_mutex_init(&ss->lock, nullptr);
    hipError_t eventErr = hipEventCreateWithFlags(&ss->serialEvent, hipEventDisableTiming);
    if (eventErr != hipSuccess) {
      // Roll back the partial construction, then report the original failure.
      pthread_mutex_destroy(&ss->lock);
      (void)hipStreamDestroy(ss->liveStream);
      CUDACHECK(eventErr);
    }
  #endif
  return ncclSuccess;
}

// Tears down a strong stream built by ncclStrongStreamConstruct: the live
// stream and, on ROCm >= 6.1, every capture entry (its stream and the entry
// itself), the serialEvent, and the lock.
// Teardown is best-effort: every resource is released even if one destroy call
// fails, and the first HIP failure is reported at the end.
ncclResult_t ncclStrongStreamDestruct(struct ncclStrongStream* ss) {
  hipError_t firstErr = hipSuccess;
  auto keepFirst = [&firstErr](hipError_t e) {
    if (firstErr == hipSuccess) firstErr = e;
  };

  keepFirst(hipStreamDestroy(ss->liveStream));
  #if ROCM_VERSION >= 60100
    struct ncclStrongStreamCapture* cap = ss->captureHead;
    while (cap != nullptr) {
      struct ncclStrongStreamCapture* next = cap->next;
      keepFirst(hipStreamDestroy(cap->captureStream));
      free(cap);
      cap = next;
    }
    ss->captureHead = nullptr;  // Do not leave a dangling list behind.
    keepFirst(hipEventDestroy(ss->serialEvent));
    pthread_mutex_destroy(&ss->lock);
  #endif
  CUDACHECK(firstErr);
  return ncclSuccess;
}

// Tunables declared through NCCL_PARAM (third argument presumably the default — confirm in param.h):
//  - GraphMixingSupport ("GRAPH_MIXING_SUPPORT", 0): when nonzero, Acquire/Release
//    order captured and uncaptured work on a strong stream through ss->serialEvent.
//  - LaunchRaceFatal ("LAUNCH_RACE_FATAL", 1): when nonzero (and graph mixing is
//    on), Release returns ncclInvalidUsage if called from a different host
//    thread than the one that last acquired the stream.
NCCL_PARAM(GraphMixingSupport, "GRAPH_MIXING_SUPPORT", 0)
NCCL_PARAM(LaunchRaceFatal, "LAUNCH_RACE_FATAL", 1);
// Message logged when the launch-race check trips.
constexpr char const* launchRaceFatalMsg = "Fatal: host threads racing to launch NCCL on same device.";

// One byte of thread-local storage per host thread. Its address serves as a cheap
// identifier of the calling thread for the launch-race checks.
static __thread char threadIdMarker;
static void* localThreadId() {
  char* marker = &threadIdMarker;
  return static_cast<void*>(marker);
}

// Hands out the stream on which the caller should enqueue work for `ss`.
//  - Uncaptured (graph.graphId == ULLONG_MAX): returns ss->liveStream. With
//    GRAPH_MIXING_SUPPORT on, and once the strong stream has ever been captured,
//    the live stream first waits on ss->serialEvent.
//  - Captured: returns this graph's private capture stream from ss->captureHead.
//    The entry is created, or an inactive entry is recycled, the first time a graph
//    is seen. While searching, entries whose capture is no longer active are pruned.
//    A new entry's stream is joined to the capture via graph.origin with its
//    capture dependencies cleared; with mixing on it then waits on ss->serialEvent.
// `concurrent` selects whether ss->lock guards the capture list.
// The acquiring thread is recorded for the race check in ncclStrongStreamRelease.
ncclResult_t ncclStrongStreamAcquire(
   struct ncclCudaGraph graph, struct ncclStrongStream* ss, bool concurrent,
   hipStream_t* workStream
  ) {
  #if ROCM_VERSION >= 60100
    bool mixing = ncclParamGraphMixingSupport();
    if (graph.graphId == ULLONG_MAX) {
      *workStream = ss->liveStream;
      ss->liveAcquiredBy = localThreadId();
      if (mixing && __atomic_load_n(&ss->everCaptured, __ATOMIC_RELAXED)) {
        CUDACHECK(hipStreamWaitEvent(ss->liveStream, ss->serialEvent, 0));
      }
    } else {
      // Single atomic test-and-set: exactly one caller observes the first capture.
      bool firstCapture = !__atomic_exchange_n(&ss->everCaptured, true, __ATOMIC_RELAXED);

      ncclResult_t ret = ncclSuccess;
      bool alreadyJoined = false;        // An entry for this graph already existed.
      hipStream_t captureStream = nullptr;
      struct ncclStrongStreamCapture** pcap = &ss->captureHead;
      struct ncclStrongStreamCapture* cap = nullptr;
      struct ncclStrongStreamCapture* spare = nullptr; // Unlinked inactive entry, owned by us.
      if (concurrent) pthread_mutex_lock(&ss->lock);

      // Look for capture in our list of active captures, pruning finished ones.
      while (*pcap != nullptr) {
        cap = *pcap;
        if (cap->graphId == graph.graphId) { // Capture node already exists.
          alreadyJoined = true;
          break;
        }
        hipStreamCaptureStatus status;
        CUDACHECKGOTO(hipStreamIsCapturing(cap->captureStream, &status), ret, do_unlock);
        if (status == hipStreamCaptureStatusActive) {
          pcap = &cap->next; // Active capture doesn't match, on to next.
        } else { // Capture no longer active
          *pcap = cap->next; // Remove from current list
          if (spare == nullptr) { // Keep one spare to reuse below.
            spare = cap;
          } else {
            (void)hipStreamDestroy(cap->captureStream);
            free(cap);
          }
        }
      }

      if (!alreadyJoined) {
        // No matching capture, need a new entry: recycle the spare or allocate one.
        if (spare != nullptr) {
          cap = spare;
          spare = nullptr;
        } else {
          cap = (struct ncclStrongStreamCapture*)calloc(1, sizeof(struct ncclStrongStreamCapture));
          if (cap == nullptr) {
            WARN("Failed to allocate ncclStrongStreamCapture");
            ret = ncclSystemError;
            goto do_unlock;
          }
          hipError_t createErr = hipStreamCreateWithFlags(&cap->captureStream, hipStreamNonBlocking);
          if (createErr != hipSuccess) {
            free(cap); // Not yet linked anywhere; do not leak it.
            CUDACHECKGOTO(createErr, ret, do_unlock);
          }
        }
        cap->graphId = graph.graphId;
        // Push to capturing list.
        cap->next = ss->captureHead;
        ss->captureHead = cap;
      }
      cap->acquiredBy = localThreadId();
      captureStream = cap->captureStream;

    do_unlock:
      if (concurrent) pthread_mutex_unlock(&ss->lock);
      // A spare that was pruned but not recycled (match found, or error) is ours to release.
      if (spare != nullptr) {
        (void)hipStreamDestroy(spare->captureStream);
        free(spare);
      }
      if (ret != ncclSuccess) return ret;

      *workStream = captureStream;
      // An existing entry was already joined to the graph on its first acquire.
      if (alreadyJoined) return ncclSuccess;

      // Bring captureStream into the graph but without any dependencies.
      hipEvent_t scratch;
      CUDACHECK(hipEventCreateWithFlags(&scratch, hipEventDisableTiming));
      hipError_t joinErr = hipEventRecord(scratch, graph.origin);
      if (joinErr == hipSuccess) joinErr = hipStreamWaitEvent(captureStream, scratch, 0);
      hipError_t destroyErr = hipEventDestroy(scratch); // Always release the scratch event.
      CUDACHECK(joinErr);
      CUDACHECK(destroyErr);
      #if CUDART_VERSION >= 13000
      CUDACHECK(cudaStreamUpdateCaptureDependencies_v2(captureStream, nullptr, nullptr, 0, hipStreamSetCaptureDependencies));
      #else
      CUDACHECK(hipStreamUpdateCaptureDependencies(captureStream, nullptr, 0, hipStreamSetCaptureDependencies));
      #endif

      if (mixing && firstCapture) {
        CUDACHECK(hipEventRecord(ss->serialEvent, ss->liveStream));
      }
      if (mixing) {
        // First dependency is to wait on serialEvent
        CUDACHECK(hipStreamWaitEvent(captureStream, ss->serialEvent, 0));
      }
    }
  #else
    // No capture support compiled in: always use the live stream.
    *workStream = ss->liveStream;
  #endif
  return ncclSuccess;
}

// Looks up (without acquiring) the work stream previously handed out by
// ncclStrongStreamAcquire for `graph`: the live stream when uncaptured,
// otherwise the graph's capture stream from ss->captureHead.
// `concurrent` selects whether ss->lock guards the capture-list walk.
// Returns ncclInternalError if no capture entry exists for graph.graphId
// (i.e. Acquire was not called for this graph first).
ncclResult_t ncclStrongStreamAcquiredWorkStream(
    struct ncclCudaGraph graph, struct ncclStrongStream* ss, bool concurrent,
    hipStream_t* workStream
  ) {
  #if ROCM_VERSION >= 60100
    if (graph.graphId == ULLONG_MAX) {
      *workStream = ss->liveStream;
    } else {
      if (concurrent) pthread_mutex_lock(&ss->lock);
      struct ncclStrongStreamCapture* cap = ss->captureHead;
      while (cap != nullptr && cap->graphId != graph.graphId) cap = cap->next;
      if (cap != nullptr) *workStream = cap->captureStream;
      if (concurrent) pthread_mutex_unlock(&ss->lock);
      if (cap == nullptr) {
        WARN("No capture entry for graph %llu; ncclStrongStreamAcquire must be called first", graph.graphId);
        return ncclInternalError;
      }
    }
  #else
    *workStream = ss->liveStream;
  #endif
  return ncclSuccess;
}

// Releases the stream previously handed out by ncclStrongStreamAcquire.
// Does nothing unless GRAPH_MIXING_SUPPORT is enabled. When it is:
//  - Uncaptured (graph.graphId == ULLONG_MAX): once the strong stream has ever
//    been captured, records ss->serialEvent on the live stream so later captured
//    work can be ordered after it.
//  - Captured: adds an event-record node for ss->serialEvent to graph.graph that
//    depends on the current capture frontier of this graph's capture stream, then
//    makes that record node the capture stream's only capture dependency.
// In both cases, if LAUNCH_RACE_FATAL is set and the releasing thread is not the
// one that last acquired the stream, logs a warning and returns ncclInvalidUsage.
// HIP failures are reported through CUDACHECK.
ncclResult_t ncclStrongStreamRelease(
    struct ncclCudaGraph graph, struct ncclStrongStream* ss, bool concurrent
  ) {
  #if ROCM_VERSION >= 60100
    bool mixing = ncclParamGraphMixingSupport();
    if (mixing) {
      if (graph.graphId == ULLONG_MAX) {
        // Publish the live stream's progress for future captures to wait on.
        if (__atomic_load_n(&ss->everCaptured, __ATOMIC_RELAXED)) {
          CUDACHECK(hipEventRecord(ss->serialEvent, ss->liveStream));
        }
        if (ss->liveAcquiredBy != localThreadId() && ncclParamLaunchRaceFatal()) {
          WARN("%s", launchRaceFatalMsg);
          return ncclInvalidUsage;
        }
      } else {
        // Find this graph's capture entry.
        // NOTE(review): no nullptr check — assumes ncclStrongStreamAcquire was
        // called for this graph first; otherwise the walk runs off the list.
        if (concurrent) pthread_mutex_lock(&ss->lock);
        struct ncclStrongStreamCapture* cap = ss->captureHead;
        while (cap->graphId != graph.graphId) cap = cap->next;
        if (concurrent) pthread_mutex_unlock(&ss->lock);

        // Add event record node with dependencies added further down.
        hipGraphNode_t recordNode;
        CUDACHECK(hipGraphAddEventRecordNode(&recordNode, graph.graph, nullptr, 0, ss->serialEvent));

        // Get current nodes from work stream so we can add them as dependencies.
        hipStreamCaptureStatus status;
        hipGraphNode_t const* nodes;
        size_t count = 0;
        // NOTE(review): the CUDART_VERSION-guarded branches below use CUDA-only
        // API names and are presumably dead in HIP builds — confirm.
        #if CUDART_VERSION >= 13000
        hipError_t res = hipStreamGetCaptureInfo_v3(cap->captureStream, &status, nullptr, nullptr, &nodes, nullptr, &count);
        #else
        hipError_t res = hipStreamGetCaptureInfo_v2(cap->captureStream, &status, nullptr, nullptr, &nodes, &count);
        #endif

        #if CUDART_VERSION >= 12030
        if (res == hipErrorLossyQuery) { // CUDA is telling us the dependencies have edge annotations.
          hipGraphEdgeData const* edges;
          CUDACHECK(cudaStreamGetCaptureInfo_v3(cap->captureStream, &status, nullptr, nullptr, &nodes, &edges, &count));
          for (int i=0; i < (int)count; i++) {
            CUDACHECK(cudaGraphAddDependencies_v2(graph.graph, &nodes[i], &recordNode, &edges[i], 1));
          }
        }
        #else
        if (false) {}
        #endif
        else {
          // Plain (unannotated) dependencies: make each frontier node feed recordNode.
          CUDACHECK(res /* = cudaStreamGetCaptureInfo_v2(...)*/);
          for (int i=0; i < (int)count; i++) {
          #if CUDART_VERSION >= 13000
            CUDACHECK(cudaGraphAddDependencies_v2(graph.graph, &nodes[i], &recordNode, nullptr, 1));
          #else
            CUDACHECK(hipGraphAddDependencies(graph.graph, &nodes[i], &recordNode, 1));
          #endif
          }
        }

        // Make every future operation captured on cap->captureStream depend on 'recordNode'.
        #if CUDART_VERSION >= 13000
        CUDACHECK(cudaStreamUpdateCaptureDependencies_v2(
                    cap->captureStream,
                    &recordNode,          /* dependencies                */
                    /*edges =*/ nullptr,  /* no edge annotations         */
                    1,                    /* count                       */
                    hipStreamSetCaptureDependencies));
        #else
        CUDACHECK(hipStreamUpdateCaptureDependencies(
                    cap->captureStream,
                    &recordNode,
                    1,
                    hipStreamSetCaptureDependencies));
        #endif

        // Race check: the releasing thread must be the one that acquired this entry.
        if (cap->acquiredBy != localThreadId() && ncclParamLaunchRaceFatal()) {
          WARN("%s", launchRaceFatalMsg);
          return ncclInvalidUsage;
        }
      }
    }
  #endif
  return ncclSuccess;
}

// Makes stream `a` wait for all work currently enqueued on stream `b`.
// `scratchEvent` is recorded on `b` and used as the hand-off marker; its
// previous state is overwritten.
ncclResult_t ncclStreamWaitStream(hipStream_t a, hipStream_t b, hipEvent_t scratchEvent) {
  hipEvent_t marker = scratchEvent;
  hipStream_t producer = b;
  hipStream_t consumer = a;
  CUDACHECK(hipEventRecord(marker, producer));
  CUDACHECK(hipStreamWaitEvent(consumer, marker, 0));
  return ncclSuccess;
}

// Orders stream `s` after event `e`.
//  - Uncaptured (g.graphId == ULLONG_MAX): a plain hipStreamWaitEvent.
//  - Captured: a temporary stream waits on `e`, its capture dependencies are
//    read back, and `s`'s capture dependencies are replaced with them
//    (hipStreamSetCaptureDependencies).
// The temporary stream is destroyed on every path, including errors; if an
// earlier step failed, that failure is the one returned.
ncclResult_t ncclStreamAdvanceToEvent(struct ncclCudaGraph g, hipStream_t s, hipEvent_t e) {
  if (g.graphId == ULLONG_MAX) {
    CUDACHECK(hipStreamWaitEvent(s, e, 0));
    return ncclSuccess;
  }

  ncclResult_t ret = ncclSuccess;
  hipStream_t tmp;
  CUDACHECK(hipStreamCreateWithFlags(&tmp, hipStreamNonBlocking));
  CUDACHECKGOTO(hipStreamWaitEvent(tmp, e, 0), ret, done);
  {
    hipStreamCaptureStatus status;
    hipGraphNode_t const* nodes;
    size_t count = 0;
    #if CUDART_VERSION >= 13000
    hipError_t res = hipStreamGetCaptureInfo_v3(tmp, &status, nullptr, nullptr, &nodes, nullptr, &count);
    #else
    hipError_t res = hipStreamGetCaptureInfo_v2(tmp, &status, nullptr, nullptr, &nodes, &count);
    #endif

    #if CUDART_VERSION >= 12030
    if (res == hipErrorLossyQuery) { // CUDA is telling us the dependencies have edge annotations.
      hipGraphEdgeData const* edges;
      CUDACHECKGOTO(cudaStreamGetCaptureInfo_v3(tmp, &status, nullptr, nullptr, &nodes, &edges, &count), ret, done);
      CUDACHECKGOTO(cudaStreamUpdateCaptureDependencies_v2(s, (hipGraphNode_t*)nodes, edges, count, hipStreamSetCaptureDependencies), ret, done);
    }
    #else
    if (false) {}
    #endif
    else {
      CUDACHECKGOTO(res, ret, done);
    #if CUDART_VERSION >= 13000
      CUDACHECKGOTO(cudaStreamUpdateCaptureDependencies_v2(s, (hipGraphNode_t*)nodes, nullptr, count, hipStreamSetCaptureDependencies), ret, done);
    #else
      CUDACHECKGOTO(hipStreamUpdateCaptureDependencies(s, (hipGraphNode_t*)nodes, count, hipStreamSetCaptureDependencies), ret, done);
    #endif
    }
  }

done:
  {
    // Always release tmp; on an earlier failure report that failure instead.
    hipError_t destroyErr = hipStreamDestroy(tmp);
    if (ret != ncclSuccess) return ret;
    CUDACHECK(destroyErr);
  }
  return ncclSuccess;
}

// Blocks the host until the strong stream's live stream drains. On ROCm >= 6.1,
// the live stream is first made to wait on ss->serialEvent, so work ordered
// through that event is included.
ncclResult_t ncclStrongStreamSynchronize(struct ncclStrongStream* ss) {
  hipStream_t live = ss->liveStream;
  #if ROCM_VERSION >= 60100
    CUDACHECK(hipStreamWaitEvent(live, ss->serialEvent, 0));
  #endif
  CUDACHECK(hipStreamSynchronize(live));
  return ncclSuccess;
}
