#!/usr/pkg/bin/python3.12

import getopt, sys, os


# Find the location of the Sphinxtrain files relative to this script.
# If a bin/Release or scripts/00.verify directory sits next to the script's
# parent, use that tree; otherwise use libexec/sphinxtrain and lib/sphinxtrain
# under the parent directory.
currentpath = os.path.dirname(os.path.abspath(__file__))

sphinxbinpath = os.path.abspath(currentpath + "/../libexec/sphinxtrain")
if os.path.exists(currentpath + "/../bin/Release"):
    sphinxbinpath = os.path.abspath(currentpath + "/../bin/Release")

sphinxpath = os.path.abspath(currentpath + "/../lib/sphinxtrain")
if os.path.exists(currentpath + "/../scripts/00.verify"):
    sphinxpath = os.path.abspath(currentpath + "/..")


# The Perl scripts want forward slashes, so normalise Windows separators.
training_basedir = os.getcwd().replace('\\', '/')
sphinxpath = sphinxpath.replace('\\', '/')
sphinxbinpath = sphinxbinpath.replace('\\', '/')

# Single-argument print(...) behaves the same on Python 2 and Python 3.
print("Sphinxtrain path: " + sphinxpath)
print("Sphinxtrain binaries path: " + sphinxbinpath)

# Command implementations
def setup(task):
    """Create the training configuration for database *task* under ./etc.

    Copies ``etc/sphinx_train.cfg`` from the Sphinxtrain directory, replacing
    the ``___DB_NAME___``, ``___BASE_DIR___``, ``___SPHINXTRAIN_DIR___`` and
    ``___SPHINXTRAIN_BIN_DIR___`` placeholders, and copies ``etc/feat.params``
    unchanged.

    Relies on the module-level ``sphinxpath``, ``sphinxbinpath`` and
    ``training_basedir`` values computed when the script starts.
    """
    if not os.path.exists("etc"):
        os.mkdir("etc")

    print("Setting up the database " + task)

    # Applied in this order to every template line.
    substitutions = (
        ("___DB_NAME___", task),
        ("___BASE_DIR___", training_basedir),
        ("___SPHINXTRAIN_DIR___", sphinxpath),
        ("___SPHINXTRAIN_BIN_DIR___", sphinxbinpath),
    )

    # Open the template first so a missing template does not truncate an
    # existing config; `with` closes (and flushes) both files on any error.
    with open(sphinxpath + "/etc/sphinx_train.cfg", "r") as template, \
            open("etc/sphinx_train.cfg", "w") as out_cfg:
        for line in template:
            for placeholder, value in substitutions:
                line = line.replace(placeholder, value)
            out_cfg.write(line)

    with open(sphinxpath + "/etc/feat.params", "r") as template, \
            open("etc/feat.params", "w") as out_cfg:
        out_cfg.writelines(template)

# Training steps in the order run() executes them.  Each entry is a script
# path relative to <sphinxpath>/scripts; the text after the last "." in the
# directory name (e.g. "ci_hmm" for "20.ci_hmm") is the stage name that
# run_stages() matches against.
steps = [
    "000.comp_feat/slave_feat.pl",
    "00.verify/verify_all.pl",
    "0000.g2p_train/g2p_train.pl",
    "01.lda_train/slave_lda.pl",
    "02.mllt_train/slave_mllt.pl",
    "05.vector_quantize/slave.VQ.pl",
    "10.falign_ci_hmm/slave_convg.pl",
    "11.force_align/slave_align.pl",
    "12.vtln_align/slave_align.pl",
    "20.ci_hmm/slave_convg.pl",
    "30.cd_hmm_untied/slave_convg.pl",
    "40.buildtrees/slave.treebuilder.pl",
    "45.prunetree/slave.state-tying.pl",
    "50.cd_hmm_tied/slave_convg.pl",
    "60.lattice_generation/slave_genlat.pl",
    "61.lattice_pruning/slave_prune.pl",
    "62.lattice_conversion/slave_conv.pl",
    "65.mmie_train/slave_convg.pl",
    "90.deleted_interpolation/deleted_interpolation.pl",
    "decode/slave.pl",
]

def _stage_name(step):
    """Return the stage name of *step*, e.g. "ci_hmm" for "20.ci_hmm/slave_convg.pl"."""
    return step.split("/")[0].split(".")[-1]

def _run_step(step):
    """Run one step script from <sphinxpath>/scripts; abort the run if it fails.

    The script is started through os.system, so it must be directly executable.
    NOTE(review): the path is not quoted, so a sphinxpath containing spaces
    will break the command line.
    """
    status = os.system(sphinxpath + "/scripts/" + step)
    if status != 0:
        # Stop instead of silently continuing: the remaining steps presumably
        # depend on this one's output.  os.system returns a platform-specific
        # wait status, so report it verbatim and exit with the script's usual
        # error code.
        print("Step " + step + " failed with status " + str(status)
              + ", stopping the training")
        sys.exit(-1)

def run_stages(stages):
    """Run only the steps whose stage names appear in the comma-separated *stages*.

    Stages run in the order they are listed in *stages*.  Surrounding
    whitespace and empty entries are ignored, and unknown names are reported.
    Training stops at the first failing step.
    """
    for stage in stages.split(","):
        stage = stage.strip()
        if not stage:
            continue
        matching = [step for step in steps if _stage_name(step) == stage]
        if not matching:
            print("Unknown stage " + stage + ", skipping")
            continue
        for step in matching:
            _run_step(step)

def run():
    """Run every step in ``steps`` order, stopping at the first failure."""
    print("Running the training")
    for step in steps:
        _run_step(step)

def usage():
    """Print the command-line help text to standard output."""
    # One single-argument print keeps the output identical on Python 2 and 3.
    print("\n".join([
        "",
        "Sphinxtrain processes the audio files and creates an acoustic model ",
        "for the CMUSphinx toolkit. The data needs to have a certain layout. ",
        "See the tutorial http://cmusphinx.sourceforge.net/wiki/tutorialam ",
        "for details.",
        "",
        "Usage: sphinxtrain [options] <command>",
        "",
        "Commands:",
        "     -t <task> setup - copy configuration into database",
        "     [-s <stage1,stage2,stage3>] run - run the training or just selected stages",
        "",
        "Options:",
        "     -h, --help - show this help",
    ]))

def main():
    """Parse the command line and dispatch to setup() or run()/run_stages().

    Usage: sphinxtrain [-h] [-t <task>] [-s <stage1,stage2,...>] <command>
    where <command> is "setup" (requires -t/--task) or "run" (optionally
    restricted with -s/--stages).  Usage errors end with sys.exit(-1);
    -h/--help prints the help and exits successfully.
    """
    try:
        # Long options that take a value need a trailing "="; without it
        # getopt treats "--task foo" as a flag plus the positional "foo".
        opts, args = getopt.getopt(sys.argv[1:], "ht:s:",
                                   ["help", "task=", "stages="])
    except getopt.GetoptError as err:
        print(str(err))
        usage()
        sys.exit(-1)

    task = None
    stages = None

    for o, a in opts:
        if o in ("-h", "--help"):
            # Explicit help request: show it and stop, instead of continuing
            # (and possibly starting a training run).
            usage()
            sys.exit(0)
        if o in ("-t", "--task"):
            task = a
        if o in ("-s", "--stages"):
            stages = a

    if not args:
        usage()
        sys.exit(-1)

    command = args[0]

    if command == "setup":
        if task is None:
            print("No task name defined")
            sys.exit(-1)
        setup(task)
    elif command == "run":
        if stages is None:
            run()
        else:
            run_stages(stages)
    else:
        # A mistyped command must not silently launch a full training run.
        print("Unknown command " + command)
        usage()
        sys.exit(-1)

if __name__ == "__main__":
    main()
