// discrete_distribution.hpp // // Copyright (c) 2009 // Steven Watanabe // // Distributed under the Boost Software License, Version 1.0. (See // accompanying file LICENSE_1_0.txt or copy at // http://www.boost.org/LICENSE_1_0.txt) #ifndef BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED #define BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED #include #include #include #include #include #include #include #include namespace boost { namespace random { template class discrete_distribution { public: typedef WeightType input_type; typedef IntType result_type; template discrete_distribution(Iter begin, Iter end) : weights(begin, end), data(weights.size()) { std::size_t size = weights.size(); //assert(size <= (std::numeric_limits::max)()); std::vector > below_average; std::vector > above_average; WeightType weight_sum = std::accumulate(weights.begin(), weights.end(), static_cast(0)); WeightType weight_average = weight_sum / size; for(std::size_t i = 0; i < size; ++i) { if(weights[i] < weight_average) { below_average.push_back(std::make_pair(weights[i] / weight_average, static_cast(i))); } else { above_average.push_back(std::make_pair(weights[i] / weight_average, static_cast(i))); } } std::vector >::iterator b_iter = below_average.begin(), b_end = below_average.end(), a_iter = above_average.begin(), a_end = above_average.end() ; while(b_iter != b_end && a_iter != a_end) { data[b_iter->second] = std::make_pair(b_iter->first, a_iter->second); a_iter->first -= (1 - b_iter->first); if(a_iter->first < 1) { *b_iter = *a_iter++; } else { ++b_iter; } } for(; b_iter != b_end; ++b_iter) { data[b_iter->second].first = 1; } for(; a_iter != a_end; ++a_iter) { data[a_iter->second].first = 1; } } template IntType operator()(Engine& eng) const { assert(!data.empty()); boost::variate_generator > real_gen(eng, boost::uniform_01()); WeightType test = real_gen() * data.size(); IntType result = static_cast(test); if(test - result < data[result].first) { return result; } else { return(data[result].second); } } result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const { return 0; } result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const { return static_cast(weights.size() - 1); } private: std::vector weights; std::vector > data; }; } } #endif