// SignalLinkTests.cpp

// Copyright Stjepan Rajko 2007. Use, modification and
// distribution is subject to the Boost Software License, Version
// 1.0. (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)

#include "UtilityTests.h"
#include <Utility/signal_link.hpp>
#include <boost/optional/optional.hpp>

using namespace ame;
using namespace boost;

/// Source component with no input: each call to Bang() emits one void()
/// signal through signal_link's out().
class SignalVoid : public signal_link<SignalVoid, void ()>
{
public:
	/// Fire the void() output once.
	void Bang() { this->out(); }
}; // end class SignalVoid

class SignalVoidCounter : public boost::signals::trackable
{
	int cnt;
public:
	SignalVoidCounter() : cnt(0) {}
	void operator()()
	{
		cnt++; // whenever a void() signal is received, increase the counter
	}
	int GetCount()
	{
		return cnt;
	}
}; // end class SignalVoidCounter

/// Relay component: on every void() signal it receives, it emits the float it
/// was constructed with as a void(float) signal.
class SignalFloat : public signal_link<SignalFloat, void (float)>
{
	float val; // value sent out on each received void() signal
public:
	/// @param value the float to emit each time a void() signal arrives.
	/// explicit: a bare float should never silently convert into a component.
	explicit SignalFloat(float value) : val(value) {}
	/// Slot for void() signals: send out the stored value.
	void operator()()
	{
		out(val);
	}
}; // end class SignalFloat

/// Slot object that remembers the most recent value of a void(float) signal.
class SignalFloatCollector : public boost::signals::trackable
{
	optional<float> last; // empty until the first void(float) signal arrives
public:
	/// Slot for void(float) signals: store the received value.
	void operator()(float x)
	{
		last = x;
	}
	/// @return the last received value, or an empty optional if none yet.
	optional<float> GetLast() const
	{
		return last;
	}
}; // end class SignalFloatCollector

/// Relay component: on every void() signal it receives, it emits the int it
/// was constructed with as a void(int) signal.
class SignalInt : public signal_link<SignalInt, void (int)>
{
	int val; // value sent out on each received void() signal
public:
	/// @param value the int to emit each time a void() signal arrives.
	/// explicit: a bare int should never silently convert into a component.
	explicit SignalInt(int value) : val(value) {}
	/// Slot for void() signals: send out the stored value.
	void operator()() {out(val);}
}; // end class SignalInt

/// Slot object with two overloads that remember, separately, the most recent
/// value of a void(int) signal and of a void(float) signal.
class SignalIntFloatCollector : public boost::signals::trackable
{
	optional<int> last_int;     // empty until a void(int) signal arrives
	optional<float> last_float; // empty until a void(float) signal arrives
public:
	/// Slot for void(int) signals.
	void operator()(int x)
	{
		last_int = x;
	}
	/// Slot for void(float) signals.
	void operator()(float x)
	{
		last_float = x;
	}
	/// @return the last int received, or an empty optional if none yet.
	optional<int> GetLastInt() const
	{
		return last_int;
	}
	/// @return the last float received, or an empty optional if none yet.
	optional<float> GetLastFloat() const
	{
		return last_float;
	}
}; // end class SignalIntFloatCollector

/// Receives a float and emits the pair (value, value*2) as a
/// void(float, float) signal.
class SignalFloatDoubler : public signal_link<SignalFloatDoubler, void (float, float)>
{
public:
	void operator()(float val)
	{
		const float doubled = val*2;
		this->out(val, doubled);
	}
}; // end class SignalFloatDoubler

/// Receives a float and emits it in both positions, (value, value), of a
/// void(float, float) signal.
class SignalFloatDuplicator : public signal_link<SignalFloatDuplicator, void (float, float)>
{
public:
	void operator()(float val)
	{
		const float copy = val;
		this->out(copy, copy);
	}
}; // end class SignalFloatDuplicator

/// Slot object that remembers the most recent pair of values carried by a
/// void(float, float) signal.
class SignalFloat2Collector : public boost::signals::trackable
{
	optional<float> last1, last2; // both empty until the first signal arrives
public:
	/// Slot for void(float, float) signals: store both received values.
	void operator()(float val1, float val2)
	{
		last1 = val1;
		last2 = val2;
	}
	/// @return the first value of the last received pair, if any.
	optional<float> GetLast1() const
	{
		return last1;
	}
	/// @return the second value of the last received pair, if any.
	optional<float> GetLast2() const
	{
		return last2;
	}
}; // end class SignalFloat2Collector

/// Connects a void() source to a counter, then additionally chains a float
/// relay and a collector onto the same source, checking that each Bang()
/// reaches every connected slot.
void simple_test()
{
	SignalVoid banger;
	SignalVoidCounter counter;
	SignalFloat floater(2.5f);
	SignalFloatCollector collector;

	// Step 1: a single connection, banger -> counter.
	banger >>= counter;
	banger.Bang();
	BOOST_CHECK(counter.GetCount() == 1);

	// Step 2: ">>=" groups right-to-left, so floater is linked to collector and
	// banger to floater, in addition to the existing banger -> counter link.
	banger >>= floater >>= collector;
	banger.Bang();
	BOOST_CHECK(counter.GetCount() == 2); // the counter sees the second Bang as well
	BOOST_CHECK(collector.GetLast() == optional<float>(2.5f)); // floater relayed its value
} // end void simple_test()

/// Uses ">=" to fan a single source out to more than one target.
void branching_test()
{
	SignalVoid banger;
	SignalVoidCounter counter;
	SignalFloat floater(2.5f);
	SignalFloatCollector collector;

	// The parenthesised ">>=" links floater to collector; each ">=" then connects
	// banger to its right-hand operand, giving
	//   banger -> floater -> collector   and   banger -> counter.
	banger >= (floater >>= collector) >= counter;

	banger.Bang();
	BOOST_CHECK(counter.GetCount() == 1);
	BOOST_CHECK(collector.GetLast() == optional<float>(2.5f));
} // end void branching_test()

/// Checks both ways a connection can end: automatically, when the receiving
/// trackable objects are destroyed, and explicitly via disconnect_all_slots().
void disconnect_test()
{
	SignalVoid banger;
	{
		// These receivers live only in this inner scope, so they are destroyed
		// while banger is still alive.
		SignalVoidCounter counter;
		SignalFloat floater(2.5f);
		SignalFloatCollector collector;

		banger
			>= counter
			>= (floater >>= collector);

		banger.Bang();
		BOOST_CHECK(counter.GetCount() == 1);
		BOOST_CHECK(collector.GetLast() == optional<float>(2.5f));
	} // counter, floater, and collector are now gone and disconnected
	// With every target destroyed, banger's default signal must have no slots left.
	BOOST_CHECK(banger.default_signal().num_slots() == 0); 

	// A fresh counter (unrelated to the inner-scope one): connect it, then
	// immediately drop all of banger's connections.
	SignalVoidCounter counter;

	banger >>= counter;
	banger.disconnect_all_slots();

	// After disconnect_all_slots(), a Bang() must not reach the counter.
	banger.Bang();
	BOOST_CHECK(counter.GetCount() == 0);
} // end void disconnect_test

/// One collector with both int and float overloads receives from an int relay
/// and a float relay that share a single source.
void multi_type_test()
{
	SignalVoid banger;
	SignalInt inter(2);
	SignalFloat floater(3.3f);
	SignalIntFloatCollector collector;

	// banger fans out to both relays; each relay feeds the same collector.
	banger >= (inter >>= collector) >= (floater >>= collector);

	banger.Bang();
	BOOST_CHECK(collector.GetLastInt() == optional<int>(2));
	BOOST_CHECK(collector.GetLastFloat() == optional<float>(3.3f));
} // end void multi_type_test()

class SignalMultiCollector : public boost::signals::trackable
{
	optional<float> last, last1, last2;
	int cnt;
public:
	SignalMultiCollector() : cnt(0) {}
	void operator()()
	{
		cnt++;
	}
	int GetCount()
	{
		return cnt;
	}
	void operator()(float val1, float val2)
	{
		last1 = val1;
		last2 = val2;
	}
	optional<float> GetLast1()
	{
		return last1;
	}
	optional<float> GetLast2()
	{
		return last2;
	}
	void operator()(float x)
	{
		last = x;
	}
	optional<float> GetLast()
	{
		return last;
	}
}; // end class SignalMultiCollector

/// One collector with void(), void(float) and void(float, float) overloads is
/// fed, from a single Bang(), by sources of all three signatures.
void multi_num_args_test()
{
	SignalVoid banger;
	SignalFloat floater(2.5f);
	SignalFloatDuplicator duplicator;
	SignalMultiCollector collector;
	
	// Topology built below:
	//   banger -> collector                  (void())
	//   banger -> floater -> collector       (void(float))
	//   floater -> duplicator -> collector   (void(float, float))
	banger
		>= collector
		>=
		(floater
			>= collector
			>= (duplicator >>= collector));

	banger.Bang();
	BOOST_CHECK(collector.GetCount() == 1); // one void() directly from banger
	BOOST_CHECK(collector.GetLast() == optional<float>(2.5f)); // float from floater
	BOOST_CHECK(collector.GetLast1() == optional<float>(2.5f)); // pair from duplicator
	BOOST_CHECK(collector.GetLast2() == optional<float>(2.5f));
} // end void multi_num_args_test()


/// Collector assembled from three single-purpose collectors by multiple
/// inheritance. Each base contributes its own operator(), so naming the
/// object directly as a slot is ambiguous; callers refer to it through the
/// wanted base class instead. The Get* accessors have distinct names.
class SignalMultiInheritedCollector
	: public SignalVoidCounter
	, public SignalFloatCollector
	, public SignalFloat2Collector
{
}; // end class SignalMultiInheritedCollector

void multi_num_args_inherited_test()
{
	SignalVoid banger;
	SignalFloat floater(2.5f);
	SignalFloatDuplicator duplicator;
	SignalMultiInheritedCollector collector;
	
	banger
		>= (SignalVoidCounter &) collector
		>=
		(floater
			>= (SignalFloatCollector &) collector
			>= (duplicator >>= (SignalFloat2Collector &) collector));

	banger.Bang();
	BOOST_CHECK(collector.GetCount() == 1);
	BOOST_CHECK(collector.GetLast() == optional<float>(2.5f));
	BOOST_CHECK(collector.GetLast1() == optional<float>(2.5f));
	BOOST_CHECK(collector.GetLast2() == optional<float>(2.5f));
} // end void multi_num_args_inherited_test()

/// Component with two outputs: the void(float) output it gets from signal_link
/// (emitted with out()) and an extra public void(int) signal, out_int.
/// On a void() input it emits the stored value on both outputs.
class SignalOutIntFloat : public signal_link<SignalOutIntFloat, void (float)>
{
public:
	/// @param value the float emitted on every void() input.
	/// explicit: a bare float should never silently convert into a component.
	explicit SignalOutIntFloat(float value) : x(value) {}
	/// Slot for void() signals: emit x as a float, then as an int.
	void operator()()
	{
		out(x);
		out_int(static_cast<int>(x)); // truncates toward zero, e.g. 2.5f -> 2
	}
	/// Secondary output carrying the stored value converted to int.
	boost::signal<void (int)> out_int;
private:
	float x; // value emitted on both outputs
}; // end class SignalOutIntFloat

/// Connects both outputs of SignalOutIntFloat (its signal_link output and the
/// separate out_int signal) to one collector and checks both values arrive.
void multi_out_test()
{
	SignalOutIntFloat multi_out(2.5f);
	SignalIntFloatCollector collector;

	// Float output first...
	multi_out >= collector;
	// ...then the additional int output.
	multi_out.out_int >= collector;

	// Trigger the component directly; no upstream source is needed.
	multi_out();

	BOOST_CHECK(collector.GetLastFloat() == optional<float>(2.5f));
	BOOST_CHECK(collector.GetLastInt() == optional<int>(2));
} // end void multi_out_test()

/// A void() counter that also owns a second, independent counter, so one
/// object offers two separate void() inputs: itself and its `other` member.
class Signal2VoidCounter
	: public SignalVoidCounter
{
public:
	SignalVoidCounter other; // second input, counted separately from *this
}; // end class Signal2VoidCounter

/// Receiver with two void() inputs that both update a running total and emit
/// it as a void(int) signal: operator() adds 1, AltInput() adds 10.
class Signal2VoidInputs : public signal_link<Signal2VoidInputs, void(int)>
{
	int result; // running total, emitted after every update
public:
	Signal2VoidInputs() : result(0) {}
	/// Default void() input: add 1 and emit the new total.
	void operator()()
	{
		++result;
		out(result);
	}
	/// Alternate void() input (selected with slot_selector in multi_in_test):
	/// add 10 and emit the new total.
	void AltInput()
	{
		result += 10;
		out(result);
	}
	/// @return the current total.
	int GetResult() const
	{
		return result;
	}
}; // end class Signal2VoidInputs

/// Exercises receivers with more than one input:
///  - a counter that is itself a slot and also owns a second counter member;
///  - an object whose operator() and AltInput() are connected separately,
///    the latter chosen with slot_selector.
void multi_in_test()
{
	SignalVoid banger;
	Signal2VoidCounter counter;
	
	banger
		>= counter
		>= counter.other;
	
	banger.Bang();
	BOOST_CHECK(counter.GetCount() == 1);
	BOOST_CHECK(counter.other.GetCount() == 1);

	Signal2VoidInputs inputs;

	// banger keeps its earlier connections to counter and counter.other;
	// only the inputs object is checked after the next Bang().
	banger
		>= inputs
		>= slot_selector<void ()> (inputs, &Signal2VoidInputs::AltInput);

	// operator() adds 1 and AltInput() adds 10, so a single Bang() yields 11
	// regardless of which slot runs first.
	banger.Bang();
	BOOST_CHECK(inputs.GetResult() == 11);
} // end void multi_in_test

/// Minimal callable used by signal_hpp_test to check that a bound member
/// function can be connected to a plain boost::signal. It holds no state; the
/// former unused, never-initialised `int x` member was removed.
class TestReceive
{
public:
	void operator()() {}
};

/// Free function with a void() signature, connected to a signal in
/// signal_hpp_test. Intentionally does nothing.
void test_recv()
{}

/// Registers every signal_link test case with the given Boost.Test suite.
/// Before registering, it connects a bound member function and a free function
/// to a plain boost::signal. That signal is never invoked, so this presumably
/// serves only as a compile-level smoke check (NOTE(review): confirm intent).
void signal_hpp_test(test_suite *test)
{
	// Smoke check: boost::signal accepts a bind()-wrapped member function of an
	// object held by reference, as well as a plain function pointer.
	boost::signal<void()> smoke_signal;
	TestReceive receiver;
	smoke_signal.connect(boost::bind(&TestReceive::operator(), ref(receiver)));
	smoke_signal.connect(&test_recv);
	// Disconnecting by slot is not exercised:
	//	smoke_signal.disconnect(ref(receiver));
	//	smoke_signal.disconnect(&test_recv);

	// Register the test cases defined above.
	test->add(BOOST_TEST_CASE(&simple_test));
	test->add(BOOST_TEST_CASE(&branching_test));
	test->add(BOOST_TEST_CASE(&disconnect_test));
	test->add(BOOST_TEST_CASE(&multi_type_test));
	test->add(BOOST_TEST_CASE(&multi_num_args_test));
	test->add(BOOST_TEST_CASE(&multi_num_args_inherited_test));
	test->add(BOOST_TEST_CASE(&multi_out_test));
	test->add(BOOST_TEST_CASE(&multi_in_test));
}