Boost logo

Boost Users :

From: Hicham Mouline (hicham_at_[hidden])
Date: 2008-05-03 12:43:03


Hello,

I have written a small set of templates to represent a 1-variable real
mathematical function, with the help of "expression templates",
mentioned in blitz++ for arrays, and a simple metafunction to calculate the
symbolic derivative.

I feel I have reinvented the wheel but I don't know which libs from boost to
reuse (MPL, function, bind...), or POOMA (the expr templates part of it).

I have written all the templates as purely compile-time objects with all the
function data stored in types. The main drawback is the inability
to use floating-point literals.

The objective is to store arbitrary 1-variable real functions at compile
time, for example:
   f(x) = x^2 + sqrt(x - 3)
and then to have a metafunction to calculate f'(x), the derivative.
(Partial derivatives are for later.)

Here is what I came up with. Not sure my nomenclature is correct, I came
across the terms in misc lib docs:

The parse tree is:
A "Function" can be a "Terminal", an "Application Expression", a "Unary
expression" or a "Binary Expression" (The interface is inadequate because
the user instantiates a new type (a new symbolic function) very verbosely).

A "Terminal" can be a "Literal", a "Variable" or an "Elementary Function".
A "Literal" is a class template whose template argument is an int (drawback:
floating-point values cannot be stored this way).
A "Variable" just represents "x" in the usual notation "f(x)".
An "Elementary Function" is a wrapper around <cmath>'s standard and perhaps
boost::math's special 1-variable functions, like sin, cos, exp, ....

"Application Expression" is to represent sin ( x - 2 ), sort of like a
callable type.
"Unary Expression" is to represent op R (I have 2 so far: +R or -R).
"Binary Expression" is to represent L op R (for example L / R).

Op is an "Operator" (plus, minus, multiplies, divides, power)...

As I said, I feel this must have been done already?
Is there a library I can reuse?

Is there any way to represent floating points while staying totally
compile-time?

Below is the code with an example call:

Best regards,

----------------------------------------------------------------------------
-------
#include <iostream>
#include <typeinfo>
#include <cmath>
#include <boost/type_traits/is_same.hpp>
#include <boost/static_assert.hpp>

namespace MathFunction {

// Root tag carried by every node of the expression tree. A node type Node
// derives from FunctionTag<Node>, so Node::function names Node itself; the
// expression templates in this file check exactly that with static asserts.
template <class Node>
struct FunctionTag
{
        typedef Node function;
};

// Terminals
// Tag for leaf nodes of the expression tree. On top of the `function`
// typedef inherited from FunctionTag it exposes `terminal`, which names the
// derived leaf type.
template <class Leaf>
struct TerminalTag : FunctionTag<Leaf>
{
        typedef Leaf terminal;
};

// Integer constant N encoded in the type. The derivative of any literal is
// the zero literal.
template <int N>
struct Literal : TerminalTag< Literal<N> >
{
        typedef Literal<0> derivative;

        // Returns N as a double. The argument is not read; a constant has the
        // same value at every point.
        static double eval(double variable)
        {
                const double constant = static_cast<double>(N);
                return constant;
        }
};

// The independent variable "x" of f(x). Its derivative is the literal 1.
struct Variable : TerminalTag<Variable>
{
        typedef Literal<1> derivative;

        // Evaluating x at a point simply yields that point.
        static double eval(double variable)
        {
                const double x = variable;
                return x;
        }
};

// Forward declarations of the two composite expression templates. They are
// named inside `derivative` typedefs before their definitions appear in the
// "Expression" section further down.
template <typename OpTag, typename R>
struct UnaryExpression;

template <typename L, typename OpTag, typename R>
struct BinaryExpression;

// Forward declarations of the operator tags; the full definitions live in
// the "Operators" section below.
struct plus;
struct minus;
struct multiplies;
struct divides;
struct power;

// Elementary functions
// Tag for elementary one-variable functions. Adds the `elementaryfunction`
// typedef (naming the derived type) to the `function` typedef inherited
// from FunctionTag.
template <class Elem>
struct ElemFunctionTag : FunctionTag<Elem>
{
        typedef Elem elementaryfunction;
};

struct Cos;     // defined right after Sin; named here by Sin::derivative

// Sine as an elementary function: sin' = cos.
struct Sin : ElemFunctionTag<Sin>
{
        typedef Cos derivative;

        // Evaluates sin at the given point via <cmath>.
        static double eval(double variable)
        {
                const double result = std::sin(variable);
                return result;
        }
};
// Cosine as an elementary function: cos' = -sin.
//
// NOTE(review): the derivative is a UnaryExpression, which carries no
// `elementaryfunction` typedef, so differentiating an
// ApplicationExpression<Cos, R> looks like it will trip that template's
// static assertion on its first argument - confirm before relying on it.
struct Cos : ElemFunctionTag<Cos> {
        typedef UnaryExpression<minus,Sin> derivative;

        // Evaluates cos at the given point via <cmath>. Without this member
        // Function<Cos>::eval and ApplicationExpression<Cos,R>::eval cannot
        // compile.
        static double eval(double variable)
        {
                return std::cos(variable);
        }
};
// Tangent. No `derivative` typedef is given: tan' = 1/cos^2 is not itself
// one of the elementary function tags.
struct Tan : ElemFunctionTag<Tan> {
        // Evaluates tan at the given point via <cmath>.
        static double eval(double variable)
        {
                return std::tan(variable);
        }
};

// Arc tangent. No `derivative` typedef is given: atan' = 1/(1 + x^2) is not
// itself one of the elementary function tags.
struct ATan : ElemFunctionTag<ATan> {
        // Evaluates atan at the given point via <cmath>.
        static double eval(double variable)
        {
                return std::atan(variable);
        }
};

// Exponential: exp' = exp.
struct Exp : ElemFunctionTag<Exp> {
        typedef Exp derivative;

        // Evaluates exp at the given point via <cmath>.
        static double eval(double variable)
        {
                return std::exp(variable);
        }
};

// Dedicated printers, in the same "name x" form used for Sin and Cos. With
// these in place, streaming one of these tags no longer falls back to the
// generic ElemFunctionTag printer.
std::ostream& operator<<( std::ostream& os, const Tan& )
{
        os << "tan x";
        return os;
}

std::ostream& operator<<( std::ostream& os, const ATan& )
{
        os << "atan x";
        return os;
}

std::ostream& operator<<( std::ostream& os, const Exp& )
{
        os << "exp x";
        return os;
}

// Operators
// Marks an operator tag as usable in unary position: Op::unaryoper names Op
// when Op derives from UnaryOperatorTag<Op>.
template <class Op>
struct UnaryOperatorTag
{
        typedef Op unaryoper;
};
// Marks an operator tag as usable in binary position: Op::binaryoper names
// Op when Op derives from BinaryOperatorTag<Op>.
template <class Op>
struct BinaryOperatorTag
{
        typedef Op binaryoper;
};

struct plus : UnaryOperatorTag<plus>, BinaryOperatorTag<plus> {
        static double eval(double var)
        {
                return var;
        }
        static double eval(double Lvar, double Rvar)
        {
                return Lvar + Rvar;
        }

        template <typename L, typename R>
        struct derivative {
                typedef BinaryExpression<L::derivative, plus, R::derivative>
type;
        };
};
struct minus : UnaryOperatorTag<minus>, BinaryOperatorTag<minus> {
        static double eval(double var)
        {
                return -var;
        }
        static double eval(double Lvar, double Rvar)
        {
                return Lvar - Rvar;
        }
        template <typename L, typename R>
        struct derivative {
                typedef BinaryExpression<L::derivative, minus,
R::derivative> type;
        };
};
struct multiplies : BinaryOperatorTag<multiplies> {
        static double eval(double Lvar, double Rvar)
        {
                return Lvar * Rvar;
        }
        template <typename L, typename R>
        struct derivative {
                typedef BinaryExpression<L::derivative, multiplies, R> left;
                typedef BinaryExpression<L, multiplies, R::derivative>
right;
                typedef BinaryExpression<left, plus, right> type;
        };
};
struct divides : BinaryOperatorTag<divides> {
        static double eval(double Lvar, double Rvar)
        {
                return Lvar/Rvar;
        }
        template <typename L, typename R>
        struct derivative {
                typedef BinaryExpression<L::derivative, multiplies, R> left;
                typedef BinaryExpression<L, multiplies, R::derivative>
right;
                typedef BinaryExpression<R, power, Literal<2> > bottom;
                typedef BinaryExpression<left, minus, right> top;
                typedef BinaryExpression<top, divides, bottom> type;
        };
};

struct power : BinaryOperatorTag<power> {
        static double eval(double Lvar, double Rvar)
        {
                return std::pow(Lvar,Rvar);
        }
        template <typename L, typename R>
        struct derivative {
                typedef BinaryExpression<R, minus, Literal<1> > Rminus1;
                typedef BinaryExpression<L, power, Rminus1> right;
                typedef BinaryExpression<R, multiplies, right> type;
        };
};
// Operator tag for the L-th root of R (left operand is the degree, right
// operand is the radicand).
struct root : BinaryOperatorTag<root> {
        // Returns Rvar ^ (1 / Lvar) via std::pow.
        static double eval(double Lvar, double Rvar)
        {
                return std::pow(Rvar,1.0/Lvar);
        }

        // Derivative of the L-th root of R for a degree L that does not
        // depend on x:
        //     (R^(1/L))' = (1/L) * R^(1/L - 1) * R' = (root(L,R) * R') / (L * R)
        // BinaryExpression requires every operator tag to provide this nested
        // template; without it no BinaryExpression<L, root, R> could be
        // instantiated at all.
        template <typename L, typename R>
        struct derivative {
                typedef BinaryExpression<L, root, R> self;
                typedef BinaryExpression<self, multiplies,
                                         typename R::derivative> top;
                typedef BinaryExpression<L, multiplies, R> bottom;
                typedef BinaryExpression<top, divides, bottom> type;
        };
};

// Expression

// Expression node for "op R", where OpTag is an operator tag usable in unary
// position and R is any function node.
template <typename OpTag, typename R>
struct UnaryExpression : FunctionTag< UnaryExpression<OpTag,R> > {
        // OpTag must derive from UnaryOperatorTag<OpTag>, and R must derive
        // from FunctionTag<R>. `typename` is required on the dependent
        // member types.
        BOOST_STATIC_ASSERT(( boost::is_same<typename OpTag::unaryoper,OpTag>::value ));
        BOOST_STATIC_ASSERT(( boost::is_same<typename R::function,R>::value ));

        // The same operator applied to the operand's derivative:
        // (op R)' = op (R').
        typedef UnaryExpression<OpTag, typename R::derivative> derivative;

        // Evaluates the operand at `variable`, then applies the operator.
        static double eval(double variable)
        {
                return OpTag::eval(R::eval(variable));
        }
};

// Expression node for "L op R", where L and R are function nodes and OpTag
// is an operator tag usable in binary position.
template <typename L, typename OpTag, typename R>
struct BinaryExpression : FunctionTag< BinaryExpression<L,OpTag,R> > {
        // L and R must derive from FunctionTag of themselves, and OpTag from
        // BinaryOperatorTag<OpTag>. `typename` is required on the dependent
        // member types.
        BOOST_STATIC_ASSERT(( boost::is_same<typename L::function,L>::value ));
        BOOST_STATIC_ASSERT(( boost::is_same<typename OpTag::binaryoper,OpTag>::value ));
        BOOST_STATIC_ASSERT(( boost::is_same<typename R::function,R>::value ));

        // The differentiation rule is delegated to the operator tag.
        // OpTag::derivative is a dependent member template, so both the
        // `typename` and the `template` disambiguators are needed.
        typedef typename OpTag::template derivative<L,R>::type derivative;

        // Evaluates both operands at `variable`, then combines them with the
        // operator.
        static double eval(double variable)
        {
                return OpTag::eval(L::eval(variable), R::eval(variable));
        }
};

// Expression node for an elementary function L applied to an inner function
// node R, e.g. sin(x - 2).
template <typename L, typename R>
struct ApplicationExpression : FunctionTag< ApplicationExpression<L,R> > {
        // L must derive from ElemFunctionTag<L> and R from FunctionTag<R>.
        // `typename` is required on the dependent member types.
        BOOST_STATIC_ASSERT(( boost::is_same<typename L::elementaryfunction,L>::value ));
        BOOST_STATIC_ASSERT(( boost::is_same<typename R::function,R>::value ));

        // Chain rule: (L(R))' = L'(R) * R'.
        // L::derivative is itself used as the first argument of an
        // ApplicationExpression, so it has to satisfy the first assertion
        // above as well.
        typedef BinaryExpression<
                ApplicationExpression<typename L::derivative,R>,
                multiplies,
                typename R::derivative > derivative;

        // Evaluates the inner function at `variable` and feeds the result to
        // the elementary function.
        static double eval(double variable)
        {
                return L::eval( R::eval(variable) );
        }
};

template <typename F, int n=1>
struct Derivative {
        BOOST_STATIC_ASSERT(( n>0 ));
        typedef Derivative<F::derivative, n-1>::type type;
};

template <typename F>
struct Derivative<F,1> {
        typedef F::derivative type;
};

// Main interface
template <typename K>
struct Function {
        BOOST_STATIC_ASSERT(( boost::is_same<K::function,K>::value ));
        typedef Function<K::derivative> derivative;
        static double eval(double variable)
        {
                return K::eval(variable);
        }
};

// Prints a literal as its integer value.
template <int N>
std::ostream& operator<<( std::ostream& os, const Literal<N>& )
{
        return os << N;
}

// Prints the independent variable as the single character 'x'.
std::ostream& operator<<( std::ostream& os, const Variable& )
{
        return os << 'x';
}

// Generic printer for an elementary function seen through its tag base: it
// streams a default-constructed object of the derived type T.
// NOTE(review): if T has no more specific operator<< of its own, `os << T()`
// appears to resolve back to this template through the derived-to-base
// conversion and would then recurse without end - confirm that every
// elementary function that gets printed declares its own overload.
template <typename T>
std::ostream& operator<<( std::ostream& os, const ElemFunctionTag<T>& )
{
        return os << T();
}

// Prints the sine tag as the fixed text "sin x".
std::ostream& operator<<( std::ostream& os, const Sin& )
{
        return os << "sin x";
}

// Prints the cosine tag as the fixed text "cos x".
std::ostream& operator<<( std::ostream& os, const Cos& )
{
        return os << "cos x";
}

// Prints a function application as the elementary function followed by the
// inner expression in parentheses.
template <typename L, typename R>
std::ostream& operator<<( std::ostream& os,
                          const ApplicationExpression<L,R>& )
{
        os << L();
        os << '(';
        os << R();
        return os << ')';
}

// Prints a unary expression as the operator symbol immediately followed by
// its operand.
template <typename OpTag, typename R>
std::ostream& operator<<( std::ostream& os,
                          const UnaryExpression<OpTag,R>& )
{
        os << OpTag();
        return os << R();
}

// Prints a binary expression fully parenthesised, with one space on each
// side of the operator symbol: "(L op R)".
template <typename L, typename OpTag, typename R>
std::ostream& operator<<( std::ostream& os,
                          const BinaryExpression<L,OpTag,R>& )
{
        os << '(';
        os << L();
        os << ' ';
        os << OpTag();
        os << ' ';
        os << R();
        return os << ')';
}

// Prints the plus operator tag as '+'.
std::ostream& operator<<( std::ostream& os, const plus& )
{
        return os << '+';
}
// Prints the minus operator tag as '-'.
std::ostream& operator<<( std::ostream& os, const minus& )
{
        return os << '-';
}
// Prints the multiplication operator tag as '*'.
std::ostream& operator<<( std::ostream& os, const multiplies& )
{
        return os << '*';
}
// Prints the division operator tag as '/'.
std::ostream& operator<<( std::ostream& os, const divides& )
{
        return os << '/';
}
// Prints the exponentiation operator tag as '^'.
std::ostream& operator<<( std::ostream& os, const power& )
{
        return os << '^';
}

// Prints the root operator tag as the letter 'V'.
std::ostream& operator<<( std::ostream& os, const root& )
{
        return os << 'V';
}

// Prints a Function as "f(x) = " followed by the wrapped expression.
// The label now ends with a space; it used to be "f(x) =", which glued the
// expression to the equals sign (e.g. "f(x) =sin x").
template <typename K>
std::ostream& operator<<( std::ostream& os, const Function<K>& )
{
        os << "f(x) = " << K();
        return os;
}

}

int main(int argc, char* argv[])
{
        using namespace MathFunction;
        typedef Function< Sin > f;
        typedef Derivative<f,4>::type fprime;
        std::cout<< f() <<std::endl;
        std::cout<< fprime() <<std::endl;
}


Boost-users list run by williamkempf at hotmail.com, kalb at libertysoft.com, bjorn.karlsson at readsoft.com, gregod at cs.rpi.edu, wekempf at cox.net