//go:build windows

package ociwclayer

import (
	"archive/tar"
	"bufio"
	"context"
	"errors"
	"io"
	"os"
	"path"
	"path/filepath"
	"strings"

	winio "github.com/Microsoft/go-winio"
	"github.com/Microsoft/go-winio/backuptar"
	"github.com/Microsoft/hcsshim/internal/wclayer"
)

// WhiteoutPrefix is the base-name prefix that marks a tar entry as a whiteout.
// When an entry's base name starts with this prefix, the import removes the
// path formed by stripping the prefix instead of adding a file.
const WhiteoutPrefix = ".wh."

var (
	// mutatedFiles maps the tar entry name of each file that is mutated by
	// the import process (and must therefore be backed up and restored) to
	// the name of the backup file that is created for it directly under the
	// layer root while the entry is being written.
	mutatedFiles = map[string]string{
		"UtilityVM/Files/EFI/Microsoft/Boot/BCD":      "bcd.bak",
		"UtilityVM/Files/EFI/Microsoft/Boot/BCD.LOG":  "bcd.log.bak",
		"UtilityVM/Files/EFI/Microsoft/Boot/BCD.LOG1": "bcd.log1.bak",
		"UtilityVM/Files/EFI/Microsoft/Boot/BCD.LOG2": "bcd.log2.bak",
	}
)

// ImportLayerFromTar reads a layer from an OCI layer tar stream and extracts
// it to the specified path. The caller must specify the parent layers, if
// any, ordered from lowest to highest layer.
//
// The caller must ensure that the thread or process has acquired backup and
// restore privileges.
//
// On success it returns the total size of the layer's files, in bytes. On
// failure it returns 0 and the error; an error from writing the layer takes
// precedence over an error from closing the layer writer.
func ImportLayerFromTar(ctx context.Context, r io.Reader, path string, parentLayerPaths []string) (int64, error) {
	if mkErr := os.MkdirAll(path, 0); mkErr != nil {
		return 0, mkErr
	}

	layerWriter, openErr := wclayer.NewLayerWriter(ctx, path, parentLayerPaths)
	if openErr != nil {
		return 0, openErr
	}

	size, writeErr := writeLayerFromTar(ctx, r, layerWriter, path)
	// The writer is closed regardless of whether writing succeeded; its
	// error is only reported when writing itself did not fail.
	closeErr := layerWriter.Close()

	switch {
	case writeErr != nil:
		return 0, writeErr
	case closeErr != nil:
		return 0, closeErr
	}
	return size, nil
}

// writeLayerFromTar walks every entry of the tar stream r and applies it to
// the layer writer w:
//
//   - an entry whose base name starts with WhiteoutPrefix removes the path
//     obtained by stripping that prefix;
//   - a hard-link entry is added as a link to its link target;
//   - any other entry is added as a file and its contents are streamed via
//     writeBackupStreamFromTarAndSaveMutatedFiles, using root as the
//     directory for any backup copies.
//
// The context is checked once per entry. On success the function returns the
// sum of the sizes reported by backuptar.FileInfoFromHeader for the file
// entries; on any failure (including cancellation) it returns 0 and the
// error.
func writeLayerFromTar(ctx context.Context, r io.Reader, w wclayer.LayerWriter, root string) (int64, error) {
	var (
		tr      = tar.NewReader(r)
		scratch = bufio.NewWriter(nil) // reused for every file entry
		total   int64
	)

	// CodeQL [SM03409] False positive, `internal/safefile` package ensures tar extractions are always
	// bound to the layer root directory.
	entry, err := tr.Next()
	for err == nil {
		select {
		case <-ctx.Done():
			return 0, ctx.Err()
		default:
		}

		switch base := path.Base(entry.Name); {
		case strings.HasPrefix(base, WhiteoutPrefix):
			// Whiteout: delete the sibling path named without the prefix.
			target := path.Join(path.Dir(entry.Name), strings.TrimPrefix(base, WhiteoutPrefix))
			if err = w.Remove(filepath.FromSlash(target)); err != nil {
				return 0, err
			}
			// CodeQL [SM03409] False positive, `internal/safefile` package ensures tar extractions are always
			// bound to the layer root directory.
			entry, err = tr.Next()

		case entry.Typeflag == tar.TypeLink:
			if err = w.AddLink(filepath.FromSlash(entry.Name), filepath.FromSlash(entry.Linkname)); err != nil {
				return 0, err
			}
			// CodeQL [SM03409] False positive, `internal/safefile` package ensures tar extractions are always
			// bound to the layer root directory.
			entry, err = tr.Next()

		default:
			fileName, fileSize, basicInfo, infoErr := backuptar.FileInfoFromHeader(entry)
			if infoErr != nil {
				return 0, infoErr
			}
			if err = w.Add(filepath.FromSlash(fileName), basicInfo); err != nil {
				return 0, err
			}
			// Streaming the contents also yields the next header (or the
			// terminating error), which drives the loop condition.
			entry, err = writeBackupStreamFromTarAndSaveMutatedFiles(scratch, w, tr, entry, root)
			total += fileSize
		}
	}

	// io.EOF is the normal end of the archive; anything else is a failure.
	if errors.Is(err, io.EOF) {
		return total, nil
	}
	return 0, err
}

// writeBackupStreamFromTarAndSaveMutatedFiles reads data from a tar stream and
// writes it to a backup stream, and also saves any files that will be mutated
// by the import layer process to a backup location.
func writeBackupStreamFromTarAndSaveMutatedFiles(buf *bufio.Writer, w io.Writer, t *tar.Reader, hdr *tar.Header, root string) (nextHdr *tar.Header, err error) {
	var bcdBackup *os.File
	var bcdBackupWriter *winio.BackupFileWriter
	if backupPath, ok := mutatedFiles[hdr.Name]; ok {
		bcdBackup, err = os.Create(filepath.Join(root, backupPath))
		if err != nil {
			return nil, err
		}
		defer func() {
			cerr := bcdBackup.Close()
			if err == nil {
				err = cerr
			}
		}()

		bcdBackupWriter = winio.NewBackupFileWriter(bcdBackup, false)
		defer func() {
			cerr := bcdBackupWriter.Close()
			if err == nil {
				err = cerr
			}
		}()

		buf.Reset(io.MultiWriter(w, bcdBackupWriter))
	} else {
		buf.Reset(w)
	}

	defer func() {
		ferr := buf.Flush()
		if err == nil {
			err = ferr
		}
	}()

	return backuptar.WriteBackupStreamFromTarFile(buf, t, hdr)
}
