package lsp

import (
	"context"
	"fmt"
	"io"
	"net"
	"path/filepath"
	"slices"
	"testing"
	"time"

	"github.com/sourcegraph/jsonrpc2"

	"github.com/open-policy-agent/regal/internal/lsp/clients"
	"github.com/open-policy-agent/regal/internal/lsp/connection"
	"github.com/open-policy-agent/regal/internal/lsp/handler"
	"github.com/open-policy-agent/regal/internal/lsp/log"
	"github.com/open-policy-agent/regal/internal/lsp/types"
	"github.com/open-policy-agent/regal/internal/lsp/uri"
	"github.com/open-policy-agent/regal/internal/testutil"
	"github.com/open-policy-agent/regal/internal/util"
)

const (
	// mainRegoFileName is a file name with a leading slash; presumably it is
	// appended to a workspace directory path by tests — confirm against callers.
	mainRegoFileName = "/main.rego"
	// defaultTimeout is set based on the investigation done as part of
	// https://github.com/open-policy-agent/regal/issues/931. 20 seconds is 10x the
	// maximum time observed for an operation to complete.
	defaultTimeout = 20 * time.Second
	// defaultBufferedChannelSize is presumably the buffer size for channels
	// created by tests elsewhere in the package — TODO confirm.
	// NOTE(review): createMessageChannels uses a literal 10 rather than this
	// constant; verify whether that difference is intended.
	defaultBufferedChannelSize = 5
)

// messages maps a file base name (e.g. "main.rego") to a channel on which the
// sorted diagnostic codes published for that file are delivered, one []string
// per publish notification.
type messages map[string]chan []string

// determineTimeout returns a timeout duration based on whether
// the test suite is running with race detection, if so, a more permissive
// timeout is used.
func determineTimeout() time.Duration {
	if isRaceEnabled() {
		// based on the upper bound here, 20x slower
		// https://go.dev/doc/articles/race_detector#Runtime_Overheads
		return defaultTimeout * 20
	}

	return defaultTimeout
}

// createAndInitServer builds a LanguageServer for use in tests, starts its
// diagnostics and config workers, wires it to a client connection over an
// in-memory net.Pipe, and performs the "initialize" / "initialized" exchange.
//
// tempDir is used as the workspace root; when it is blank, an empty RootURI is
// sent. clientHandler handles messages arriving on the client side of the pipe.
//
// It returns the server and the client-side connection. Both ends of the pipe
// are closed once ctx is done.
func createAndInitServer(t *testing.T, ctx context.Context, tempDir string, clientHandler connection.HandlerFunc) (
	*LanguageServer,
	*jsonrpc2.Conn,
) {
	t.Helper()

	// This is set due to eventing being so slow in go test -race that we
	// get flakes. TODO, work out how to avoid needing this in lsp tests.
	pollingInterval := time.Duration(0)
	if isRaceEnabled() {
		pollingInterval = 10 * time.Second
	}

	// server logs are written at debug level to the test's output
	logger := log.NewLogger(log.LevelDebug, t.Output())

	// set up the server and client connections
	ls := NewLanguageServer(ctx, &LanguageServerOptions{Logger: logger, WorkspaceDiagnosticsPoll: pollingInterval})

	// both workers are handed ctx; presumably they stop when it is cancelled — confirm
	go ls.StartDiagnosticsWorker(ctx)
	go ls.StartConfigWorker(ctx)

	// in-memory, synchronous pipe: one end for the server, one for the client
	netConnServer, netConnClient := net.Pipe()

	connServer := connection.New(ctx, netConnServer, ls.Handle)
	connClient := connection.New(ctx, netConnClient, clientHandler)

	go func() {
		<-ctx.Done()
		// we need only close the pipe connections as the jsonrpc2.Conn accept the ctx
		_ = netConnClient.Close()
		_ = netConnServer.Close()
	}()

	ls.SetConn(connServer)

	// a blank tempDir means no workspace root was required.
	rootURI := ""
	if tempDir != "" {
		rootURI = uri.FromPath(clients.IdentifierGeneric, tempDir)
	}

	// 1. Client sends initialize request
	request := types.InitializeParams{RootURI: rootURI, ClientInfo: types.ClientInfo{Name: "go test"}}

	var response types.InitializeResult
	testutil.NoErr(connClient.Call(ctx, "initialize", request, &response))(t)

	// 2. Client sends initialized notification
	// no result payload is expected, hence the nil result argument
	testutil.NoErr(connClient.Call(ctx, "initialized", struct{}{}, nil))(t)

	return ls, connClient
}

// createPublishDiagnosticsHandler returns a client-side handler that records
// the diagnostics published by the server.
//
// For each publish-diagnostics request, the codes of the reported items are
// collected, sorted, and sent on the channel in messages keyed by the base
// name of the file URI. Requests for any other method are logged to out and
// acknowledged with an empty result.
//
// If the send does not complete within one second (the channel is full, or no
// channel exists for the file), the test is marked as failed and an error is
// returned from the handler.
func createPublishDiagnosticsHandler(t *testing.T, out io.Writer, messages messages) connection.HandlerFunc {
	t.Helper()

	return func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) {
		if req.Method != methodTdPublishDiagnostics {
			fmt.Fprintln(out, "createClientHandler: unexpected request method:", req.Method)

			return struct{}{}, nil
		}

		return handler.WithParams(req, func(params types.FileDiagnostics) (any, error) {
			violations := make([]string, len(params.Items))
			for i, item := range params.Items {
				violations[i] = item.Code
			}

			fileBase := filepath.Base(params.URI)
			fmt.Fprintln(out, "createPublishDiagnosticsHandler: queue", fileBase, len(messages[fileBase]))

			select {
			case messages[fileBase] <- util.Sorted(violations):
			case <-time.After(1 * time.Second):
				// This handler runs on a connection goroutine, not the goroutine
				// running the test, so t.Fatalf (FailNow) must not be used here.
				// Mark the test as failed and report the problem via the error.
				t.Errorf("timeout writing to messages channel for %s", fileBase)

				return nil, fmt.Errorf("timeout writing to messages channel for %s", fileBase)
			}

			return struct{}{}, nil
		})
	}
}

// createMessageChannels builds a messages map holding one buffered channel
// (capacity 10) for every file in files, keyed by the base name of the file's
// map key.
func createMessageChannels(files map[string]string) messages {
	baseNames := util.MapKeys(files, filepath.Base)
	channels := make(messages, len(files))

	for _, name := range baseNames {
		channels[name] = make(chan []string, 10)
	}

	return channels
}

// testRequestDataCodes reports whether requestData holds diagnostics for
// fileURI whose item codes match codes, ignoring order.
//
// A mismatch in URI or codes is logged via t and false is returned; the test
// is not failed, so callers can keep waiting for a later matching message.
// The codes slice is not modified.
func testRequestDataCodes(t *testing.T, requestData types.FileDiagnostics, fileURI string, codes []string) bool {
	t.Helper()

	if requestData.URI != fileURI {
		t.Log("expected diagnostics to be sent for", fileURI, "got", requestData.URI)

		return false
	}

	// Extract the codes from requestData.Items
	requestCodes := make([]string, len(requestData.Items))
	for i, item := range requestData.Items {
		requestCodes[i] = item.Code
	}

	slices.Sort(requestCodes)

	// Sort a copy of the expected codes: sorting the caller's slice in place
	// would mutate their data, and would race if the same slice is shared by
	// handlers running on different goroutines.
	expectedCodes := slices.Clone(codes)
	slices.Sort(expectedCodes)

	if !slices.Equal(requestCodes, expectedCodes) {
		t.Logf("waiting for items: %v, got: %v", expectedCodes, requestCodes)

		return false
	}

	t.Logf("got expected items")

	return true
}
