/*************************************************************************
 *
 *  The Contents of this file are made available subject to
 *  the terms of GNU Lesser General Public License Version 2.1.
 *
 *
 *    GNU Lesser General Public License Version 2.1
 *    =============================================
 *    Copyright 2005-2008 by Kohei Yoshida.
 *    1039 Kingsway Dr., Apex, NC 27502, USA
 *
 *    This library is free software; you can redistribute it and/or
 *    modify it under the terms of the GNU Lesser General Public
 *    License version 2.1, as published by the Free Software Foundation.
 *
 *    This library is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *    Lesser General Public License for more details.
 *
 *    You should have received a copy of the GNU Lesser General Public
 *    License along with this library; if not, write to the Free Software
 *    Foundation, Inc., 59 Temple Place, Suite 330, Boston,
 *    MA  02111-1307  USA
 *
 ************************************************************************/

#include "numeric/funcobj.hxx"

#include <exception>
#include <memory>
#include <string>
#include <vector>

#include <cmath>
#include <cstdio>
#include <cstdlib>

using namespace ::scsolver::numeric;
using namespace ::std;

/**
 * Exception raised by the checks in this test driver whenever an expectation
 * is not met.  The message handed to the constructor is copied and later
 * reported verbatim by what().
 */
class TestFailed : public ::std::exception
{
public:
    explicit TestFailed(const char* msg) :
        m_reason(msg)
    {
    }

    virtual ~TestFailed() throw() {}

    /// Return the failure description given at construction time.
    virtual const char* what() const throw()
    {
        return m_reason.c_str();
    }

private:
    ::std::string m_reason;  // owned copy of the failure description
};

/**
 * Three-variable test function: f(x1, x2, x3) = x1 + x2 + x3.
 */
class TestBaseFunc1 : public SimpleFuncObj
{
public:
    TestBaseFunc1() : SimpleFuncObj(3) {}

    virtual ~TestBaseFunc1() {}

    /// Evaluate the sum of the three variables at their current values.
    virtual double eval() const
    {
        const double first  = getVar(0);
        const double second = getVar(1);
        const double third  = getVar(2);
        return (first + second) + third;
    }

    /**
     * Human-readable form of the function, used in test output.
     */
    virtual const ::std::string getFuncString() const
    {
        return ::std::string("x1 + x2 + x3");
    }
};

/**
 * Six-variable test function: f(x1..x6) = x1 + x2 + x3 + x4 + x5 + x6.
 */
class TestBaseFunc2 : public SimpleFuncObj
{
public:
    TestBaseFunc2() : SimpleFuncObj(6) {}

    virtual ~TestBaseFunc2() {}

    /// Evaluate the sum of the six variables at their current values.
    virtual double eval() const
    {
        // Accumulate left to right, i.e. in the same order as the written-out
        // expression x1 + x2 + ... + x6, so rounding is identical.
        double sum = getVar(0);
        for (int idx = 1; idx < 6; ++idx)
            sum += getVar(idx);
        return sum;
    }

    /**
     * Human-readable form of the function, used in test output.
     */
    virtual const ::std::string getFuncString() const
    {
        return ::std::string("x1 + x2 + x3 + x4 + x5 + x6");
    }
};

/**
 * Generate a random number in the half-open interval (0, 1].
 *
 * The result is never exactly zero.  runTest() uses these values as variable
 * ratios, and a zero ratio would lead to a division by zero when the ratios
 * are compared.
 *
 * @return double random number generated
 */
double getRandomNumber()
{
    // rand() yields an integer in [0, RAND_MAX].  Shifting it up by one and
    // dividing by RAND_MAX + 1 maps it onto (0, 1].  The arithmetic is done
    // in double, so RAND_MAX + 1 cannot overflow even when RAND_MAX == INT_MAX.
    double val = static_cast<double>(::std::rand()) + 1.0;
    return val / (static_cast<double>(RAND_MAX) + 1.0);
}

/**
 * Verify that variable varIndex holds exactly varIndexValue and that every
 * other variable of the function object is still zero.  All values are
 * printed to stdout as "(v0, v1, ...)" while they are checked.
 *
 * @param rBaseFunc     function object to inspect.
 * @param varIndex      index of the variable that is expected to have changed.
 * @param varIndexValue exact value expected for that variable.
 * @throws TestFailed on the first variable that does not match.
 */
void checkVarValues(const BaseFuncObj& rBaseFunc, size_t varIndex, double varIndexValue)
{
    fprintf(stdout, "  checking variable values...\n");

    const size_t n = rBaseFunc.getVarCount();
    fprintf(stdout, "    (");
    for (size_t idx = 0; idx != n; ++idx)
    {
        const double value = rBaseFunc.getVar(idx);
        if (idx != 0)
            fputs(", ", stdout);
        fprintf(stdout, "%g", value);

        // Only the selected variable may be non-zero; all others must stay 0.
        const bool isTarget = (idx == varIndex);
        const double expected = isTarget ? varIndexValue : 0.0;
        if (value != expected)
        {
            throw TestFailed(isTarget ? "variable value is incorrect"
                                      : "locked variable value is not zero");
        }
    }
    fprintf(stdout, ")\n");
}

/**
 * Set every variable of the function object to zero, then read them back to
 * confirm the reset took effect.
 *
 * @param rFuncObj function object whose variables are reset.
 * @throws TestFailed if any variable reads back as non-zero.
 */
void resetVarValues(BaseFuncObj& rFuncObj)
{
    const size_t count = rFuncObj.getVarCount();
    vector<double> zeros(count, 0.0);
    rFuncObj.setVars(zeros);

    // Read the values back: the later checks assume a clean all-zero start.
    for (size_t idx = 0; idx != count; ++idx)
    {
        if (rFuncObj.getVar(idx) != 0.0)
            throw TestFailed("initial variable must be zero");
    }
}

/**
 * Verify that the function object's variables are in the proportions given
 * by a ratio vector.  For every i >= 1 the check is
 *   var(i) == var(0) * ratios[i] / ratios[0]
 * within a relative tolerance.  Each comparison is printed to stdout.
 *
 * @param rFuncObj function object whose variables are inspected.
 * @param ratios   expected ratios, one per variable.  It is indexed with
 *                 at(), so a vector shorter than the variable count throws
 *                 std::out_of_range.
 * @throws TestFailed if the reference ratio is zero or if any variable
 *         deviates from its expected value (including NaN results).
 */
void checkVarRatio(BaseFuncObj& rFuncObj, const vector<double>& ratios)
{
    fprintf(stdout, "  checking ratio...\n");
    size_t varCount = rFuncObj.getVarCount();
    if (varCount < 2)
        return;  // fewer than two variables: nothing to compare

    // Relative tolerance of roughly two ulps.  The expected value is
    // recomputed here and may round differently than inside the object.
    // NOTE(review): very tight; loosen if this check proves flaky.
    const double tolerance = 5.0e-16;

    double baseVar = rFuncObj.getVar(0);
    double baseRatio = ratios.at(0);
    if (baseRatio == 0.0)
        throw TestFailed("reference ratio must not be zero");

    for (size_t i = 1; i < varCount; ++i)
    {
        double actual = rFuncObj.getVar(i);
        double expected = baseVar * ratios.at(i)/baseRatio;
        double delta = actual/expected - 1.0;  // for display only
        fprintf(stdout, "    var = %g vs %g \t (delta = %g)\n", 
                actual, expected, delta);

        // Compare without dividing by 'expected' so that a zero expected value
        // is handled.  Phrasing the test as !(diff <= limit) makes a NaN
        // anywhere count as a failure instead of silently passing.
        double diff = fabs(actual - expected);
        if (!(diff <= tolerance * fabs(expected)))
            throw TestFailed("ratio is incorrect");
    }
}

/**
 * Run the locked-variable tests against one function object.
 *
 * Two scenarios are exercised:
 *  1. Locked by index: for each variable i, all variables are reset to zero.
 *     A single-variable view bound to i is then bumped repeatedly, and
 *     checkVarValues() confirms after each step that only variable i changed.
 *  2. Locked by ratio: all variables are reset and one random ratio per
 *     variable is drawn from getRandomNumber().  A single-variable view is
 *     obtained for those ratios and bumped repeatedly.  checkVarRatio()
 *     confirms after each step that the variables stay in those proportions.
 *
 * @param p function object to test.  Ownership is taken: the object is
 *          deleted when this function returns or throws.
 * @throws TestFailed (or any exception from the function object) on failure.
 */
void runTest(BaseFuncObj* p)
{
    // NOTE(review): std::auto_ptr is deprecated since C++11 and removed in
    // C++17; replace it with std::unique_ptr once the project builds as C++11.
    auto_ptr<BaseFuncObj> pFuncObj(p);
    const size_t varCount = pFuncObj->getVarCount();

    // Each scenario bumps the free variable this many times by this amount.
    const size_t stepCount = 20;
    const double stepSize = 3.0;

    fprintf(stdout, "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n");
    // "%d" does not match size_t (undefined behaviour where size_t is 64-bit).
    // C++03 has no size_t conversion specifier, so print via unsigned long.
    fprintf(stdout, "running test on function: %s (variable count: %lu)\n", 
            pFuncObj->getFuncString().c_str(),
            static_cast<unsigned long>(varCount));
    fprintf(stdout, "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n");

    // Test the locked-by-index version.
    for (size_t i = 0; i < varCount; ++i)
    {
        resetVarValues(*pFuncObj);
        SingleVarFuncObj& rSingleVarFunc = pFuncObj->getSingleVarFuncObj(i);
        for (size_t step = 0; step < stepCount; ++step)
        {
            double newVar = rSingleVarFunc.getVar() + stepSize;
            rSingleVarFunc.setVar(newVar);
            checkVarValues(*pFuncObj, i, newVar);
        }
    }

    // Test the locked-by-ratio version.
    resetVarValues(*pFuncObj);
    vector<double> ratios;
    ratios.reserve(varCount);
    for (size_t i = 0; i < varCount; ++i)
        ratios.push_back(getRandomNumber());
    SingleVarFuncObj& rRatioFunc = pFuncObj->getSingleVarFuncObjByRatio(ratios);
    for (size_t step = 0; step < stepCount; ++step)
    {
        double newVar = rRatioFunc.getVar() + stepSize;
        rRatioFunc.setVar(newVar);
        checkVarRatio(*pFuncObj, ratios);
    }
}

/**
 * Test driver: runs the locked-variable tests on a 3-variable and a
 * 6-variable function object.
 *
 * @return EXIT_SUCCESS if every test passes and EXIT_FAILURE if a
 *         std::exception (including TestFailed) escapes.  Returning a failure
 *         status lets a build system or CI job detect a failed run; previously
 *         the program exited with 0 even after printing a failure.
 */
int main()
{
    try
    {
        runTest(new TestBaseFunc1);
        runTest(new TestBaseFunc2);
        fprintf(stdout, "test successful\n");
    }
    catch (const ::std::exception& e)
    {
        fprintf(stdout, "%s\n", e.what());
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}
