#include "audio_encoder.h"
#include "shared/ffmpeg_raii.h"

extern "C" {
#include <libavcodec/codec.h>
#include <libavcodec/codec_par.h>
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
#include <libswresample/swresample.h>
#include <libavutil/channel_layout.h>
#include <libavutil/error.h>
#include <libavutil/frame.h>
#include <libavutil/mem.h>
#include <libavutil/rational.h>
#include <libavutil/samplefmt.h>
}

#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <memory>
#include <string>
#include <vector>

#include "shared/mux.h"
#include "shared/shared_defs.h"
#include "shared/timebase.h"

using namespace std;

// Opens the encoder named codec_name and prepares everything needed to feed
// it audio: the codec context, a resampler and a scratch frame.
//
//  - codec_name: name handed to avcodec_find_encoder_by_name().
//  - bit_rate:   copied into the codec context's bit_rate field.
//  - oformat:    output container format; only its flags are consulted, to
//                decide whether the codec must emit global headers.
//
// The encoder is always configured for stereo at OUTPUT_FREQUENCY, using the
// first sample format the codec lists. Any failure is fatal: a message is
// printed to stderr and the process aborts.
AudioEncoder::AudioEncoder(const string &codec_name, int bit_rate, const AVOutputFormat *oformat)
{
	const AVCodec *codec = avcodec_find_encoder_by_name(codec_name.c_str());
	if (codec == nullptr) {
		fprintf(stderr, "ERROR: Could not find codec '%s'\n", codec_name.c_str());
		abort();
	}
	if (codec->sample_fmts == nullptr) {
		// Without this check, sample_fmts[0] below would dereference null.
		fprintf(stderr, "ERROR: Codec '%s' does not list any sample formats\n", codec_name.c_str());
		abort();
	}

	ctx = avcodec_alloc_context3(codec);
	if (ctx == nullptr) {
		fprintf(stderr, "Could not allocate context for codec '%s'\n", codec_name.c_str());
		abort();
	}
	ctx->bit_rate = bit_rate;
	ctx->sample_rate = OUTPUT_FREQUENCY;
	ctx->sample_fmt = codec->sample_fmts[0];
	ctx->ch_layout.order = AV_CHANNEL_ORDER_NATIVE;
	ctx->ch_layout.nb_channels = 2;
	ctx->ch_layout.u.mask = AV_CH_LAYOUT_STEREO;
	ctx->time_base = AVRational{1, TIMEBASE};
	if (oformat->flags & AVFMT_GLOBALHEADER) {
		ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
	}
	if (avcodec_open2(ctx, codec, nullptr) < 0) {
		fprintf(stderr, "Could not open codec '%s'\n", codec_name.c_str());
		abort();
	}

	// The resampler only converts the sample format: input is packed float
	// (AV_SAMPLE_FMT_FLT), output is whatever the codec wants. Channel layout
	// and sample rate are the same on both sides.
	resampler = nullptr;
	int ok = swr_alloc_set_opts2(&resampler,
	                             /*out_ch_layout=*/&ctx->ch_layout,
	                             /*out_sample_fmt=*/ctx->sample_fmt,
	                             /*out_sample_rate=*/OUTPUT_FREQUENCY,
	                             /*in_ch_layout=*/&ctx->ch_layout,
	                             /*in_sample_fmt=*/AV_SAMPLE_FMT_FLT,
	                             /*in_sample_rate=*/OUTPUT_FREQUENCY,
	                             /*log_offset=*/0,
	                             /*log_ctx=*/nullptr);
	if (ok != 0) {
		fprintf(stderr, "Allocating resampler failed.\n");
		abort();
	}

	if (swr_init(resampler) < 0) {
		fprintf(stderr, "Could not open resample context.\n");
		abort();
	}

	// Scratch frame reused for every call to encode_audio_one_frame().
	audio_frame = av_frame_alloc();
	if (audio_frame == nullptr) {
		fprintf(stderr, "Could not allocate audio frame.\n");
		abort();
	}
}

// Releases the scratch frame, the resampler and the codec context that the
// constructor set up, in the reverse order of their creation.
AudioEncoder::~AudioEncoder()
{
	av_frame_free(&audio_frame);
	swr_free(&resampler);
	avcodec_free_context(&ctx);
}

// Feeds a chunk of interleaved stereo float samples to the encoder.
//
//  - audio:     interleaved samples, two floats per sample pair.
//  - audio_pts: timestamp of the first sample in audio, in TIMEBASE units.
//
// If the codec context reports frame_size == 0, the whole chunk is encoded
// as a single frame right away. Otherwise the samples are appended to
// audio_queue and encoded in units of exactly ctx->frame_size sample pairs;
// any remainder stays queued for the next call or for encode_last_audio().
void AudioEncoder::encode_audio(const vector<float> &audio, int64_t audio_pts)
{
	if (ctx->frame_size == 0) {
		// No queueing needed.
		assert(audio_queue.empty());
		assert(audio.size() % 2 == 0);
		if (audio.empty()) {
			// Nothing to encode. Returning here avoids indexing into an
			// empty vector and avoids asking the encoder for a frame with
			// zero samples.
			return;
		}
		encode_audio_one_frame(audio.data(), audio.size() / 2, audio_pts);
		return;
	}

	// Floats already queued from earlier calls. They precede audio_pts in
	// time, which is why the offset is subtracted when computing pts below.
	const int64_t sample_offset = audio_queue.size();

	audio_queue.insert(audio_queue.end(), audio.begin(), audio.end());

	// Number of floats in one full codec frame (two channels, interleaved).
	const size_t floats_per_frame = ctx->frame_size * 2;

	size_t sample_num = 0;
	while (sample_num + floats_per_frame <= audio_queue.size()) {
		int64_t adjusted_audio_pts = audio_pts + (int64_t(sample_num) - sample_offset) * TIMEBASE / (OUTPUT_FREQUENCY * 2);
		encode_audio_one_frame(&audio_queue[sample_num],
		                       ctx->frame_size,
		                       adjusted_audio_pts);
		sample_num += floats_per_frame;
	}

	// Drop everything that was just encoded; the leftover (less than one
	// full frame) stays at the front of the queue.
	audio_queue.erase(audio_queue.begin(), audio_queue.begin() + sample_num);

	// NOTE(review): this is the pts at the END of the incoming chunk, while
	// encode_last_audio() uses it as the pts of the leftover samples, which
	// start somewhat earlier. Confirm whether that offset is intended.
	last_pts = audio_pts + audio.size() * TIMEBASE / (OUTPUT_FREQUENCY * 2);
}

// Converts one block of interleaved stereo float samples to the codec's
// sample format, sends it to the encoder, and forwards every packet the
// encoder has ready to all registered muxes.
//
//  - audio:       pointer to num_samples * 2 interleaved floats.
//  - num_samples: number of sample pairs (per-channel sample count).
//  - audio_pts:   presentation timestamp assigned to the frame.
//
// Any libav failure is fatal: a message is printed and the process aborts.
void AudioEncoder::encode_audio_one_frame(const float *audio, size_t num_samples, int64_t audio_pts)
{
	// The libav APIs below all take the sample count as an int.
	const int sample_count = static_cast<int>(num_samples);

	audio_frame->pts = audio_pts;
	audio_frame->nb_samples = sample_count;
	audio_frame->ch_layout.order = AV_CHANNEL_ORDER_NATIVE;
	audio_frame->ch_layout.nb_channels = 2;
	audio_frame->ch_layout.u.mask = AV_CH_LAYOUT_STEREO;
	audio_frame->format = ctx->sample_fmt;
	audio_frame->sample_rate = OUTPUT_FREQUENCY;

	// The sample buffer is allocated by hand here, so it must also be freed
	// by hand at the end of this function (see av_freep() below).
	if (av_samples_alloc(audio_frame->data, nullptr, 2, sample_count, ctx->sample_fmt, 0) < 0) {
		fprintf(stderr, "Could not allocate %zu samples.\n", num_samples);
		abort();
	}

	// The input is packed, so it is a single plane; pass the address of the
	// one pointer as the array of input planes.
	if (swr_convert(resampler, audio_frame->data, sample_count, reinterpret_cast<const uint8_t **>(&audio), sample_count) < 0) {
		fprintf(stderr, "Audio conversion failed.\n");
		abort();
	}

	int err = avcodec_send_frame(ctx, audio_frame);
	if (err < 0) {
		fprintf(stderr, "avcodec_send_frame() failed with error %d\n", err);
		abort();
	}

	// Pull packets until the encoder says it needs more input.
	for ( ;; ) {  // Termination condition within loop.
		AVPacketWithDeleter pkt = av_packet_alloc_unique();
		pkt->data = nullptr;
		pkt->size = 0;
		err = avcodec_receive_packet(ctx, pkt.get());
		if (err == AVERROR(EAGAIN)) {
			break;
		}
		if (err != 0) {
			fprintf(stderr, "avcodec_receive_packet() failed with error %d\n", err);
			abort();
		}

		// Stream index 1 is hard-coded; presumably the audio stream in
		// every mux (confirm against the mux setup).
		pkt->stream_index = 1;
		pkt->flags = 0;
		for (Mux *mux : muxes) {
			mux->add_packet(*pkt, pkt->pts, pkt->dts);
		}
	}

	av_freep(&audio_frame->data[0]);
	av_frame_unref(audio_frame);
}

void AudioEncoder::encode_last_audio()
{
	if (!audio_queue.empty()) {
		// Last frame can be whatever size we want.
		assert(audio_queue.size() % 2 == 0);
		encode_audio_one_frame(&audio_queue[0], audio_queue.size() / 2, last_pts);
		audio_queue.clear();
	}

	if (ctx->codec->capabilities & AV_CODEC_CAP_DELAY) {
		// Collect any delayed frames.
		for ( ;; ) {
			AVPacketWithDeleter pkt = av_packet_alloc_unique();
			pkt->data = nullptr;
			pkt->size = 0;
			int err = avcodec_receive_packet(ctx, pkt.get());
			if (err == 0) {
				pkt->stream_index = 1;
				pkt->flags = 0;
				for (Mux *mux : muxes) {
					mux->add_packet(*pkt, pkt->pts, pkt->dts);
				}
			} else if (err == AVERROR_EOF) {
				break;
			} else {
				fprintf(stderr, "avcodec_receive_frame() failed with error %d\n", err);
				abort();
			}
		}
	}
}

// Returns a freshly allocated AVCodecParameters filled in from the opened
// codec context, wrapped in an owning AVCodecParametersWithDeleter.
//
// Allocation or copy failure is fatal: a message is printed and the process
// aborts, matching the error handling used elsewhere in this file.
AVCodecParametersWithDeleter AudioEncoder::get_codec_parameters()
{
	AVCodecParameters *codecpar = avcodec_parameters_alloc();
	if (codecpar == nullptr) {
		fprintf(stderr, "Could not allocate codec parameters.\n");
		abort();
	}

	int err = avcodec_parameters_from_context(codecpar, ctx);
	if (err < 0) {
		fprintf(stderr, "avcodec_parameters_from_context() failed with error %d\n", err);
		abort();
	}
	return AVCodecParametersWithDeleter(codecpar);
}
