You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

183 lines
5.2KB

  1. // Copyright 2015 Olivier Gillet.
  2. //
  3. // Author: Olivier Gillet (ol.gillet@gmail.com)
  4. //
  5. // Permission is hereby granted, free of charge, to any person obtaining a copy
  6. // of this software and associated documentation files (the "Software"), to deal
  7. // in the Software without restriction, including without limitation the rights
  8. // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. // copies of the Software, and to permit persons to whom the Software is
  10. // furnished to do so, subject to the following conditions:
  11. //
  12. // The above copyright notice and this permission notice shall be included in
  13. // all copies or substantial portions of the Software.
  14. //
  15. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17. // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  18. // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19. // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  20. // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  21. // THE SOFTWARE.
  22. //
  23. // See http://creativecommons.org/licenses/MIT/ for more information.
  24. //
  25. // -----------------------------------------------------------------------------
  26. //
  27. // Generates samples from various kinds of random distributions.
  28. #ifndef MARBLES_RANDOM_DISTRIBUTIONS_H_
  29. #define MARBLES_RANDOM_DISTRIBUTIONS_H_
  30. #include "stmlib/stmlib.h"
  31. #include <algorithm>
  32. #include "stmlib/dsp/dsp.h"
  33. #include "marbles/resources.h"
  34. namespace marbles {
  // Scale factor applied to the lookup position when reading an inverse cdf
  // table. Kept as a float because it is passed directly to the interpolation
  // routine; it is also used to compute the offset of the tail tables.
  const float kIcdfTableSize = 128.0f;

  // Grid dimensions used to address the pre-computed inverse cdf tables: the
  // bias parameter is scaled by (kNumBiasValues - 1) * 2 and the spread
  // parameter by (kNumRangeValues - 1) before being split into a table index
  // and an interpolation fraction.
  const size_t kNumBiasValues = 5;
  const size_t kNumRangeValues = 9;
  38. // Generates samples from beta distribution, from uniformly distributed samples.
  39. // For higher throughput, uses pre-computed tables of inverse cdfs.
  40. inline float BetaDistributionSample(float uniform, float spread, float bias) {
  41. // Tables are pre-computed only for bias <= 0.5. For values above 0.5,
  42. // symmetry is used.
  43. bool flip_result = bias > 0.5f;
  44. if (flip_result) {
  45. uniform = 1.0f - uniform;
  46. bias = 1.0f - bias;
  47. }
  48. bias *= (static_cast<float>(kNumBiasValues) - 1.0f) * 2.0f;
  49. spread *= (static_cast<float>(kNumRangeValues) - 1.0f);
  50. MAKE_INTEGRAL_FRACTIONAL(bias);
  51. MAKE_INTEGRAL_FRACTIONAL(spread);
  52. size_t cell = bias_integral * (kNumRangeValues + 1) + spread_integral;
  53. // Lower 5% and 95% percentiles use a different table with higher resolution.
  54. size_t offset = 0;
  55. if (uniform <= 0.05f) {
  56. offset = kIcdfTableSize + 1;
  57. uniform *= 20.0f;
  58. } else if (uniform >= 0.95f) {
  59. offset = 2 * (kIcdfTableSize + 1);
  60. uniform = (uniform - 0.95f) * 20.0f;
  61. }
  62. float x1y1 = stmlib::Interpolate(
  63. distributions_table[cell] + offset,
  64. uniform,
  65. kIcdfTableSize);
  66. float x2y1 = stmlib::Interpolate(
  67. distributions_table[cell + 1] + offset,
  68. uniform,
  69. kIcdfTableSize);
  70. float x1y2 = stmlib::Interpolate(
  71. distributions_table[cell + kNumRangeValues + 1] + offset,
  72. uniform,
  73. kIcdfTableSize);
  74. float x2y2 = stmlib::Interpolate(
  75. distributions_table[cell + kNumRangeValues + 2] + offset,
  76. uniform,
  77. kIcdfTableSize);
  78. float y1 = x1y1 + (x2y1 - x1y1) * spread_fractional;
  79. float y2 = x1y2 + (x2y2 - x1y2) * spread_fractional;
  80. float y = y1 + (y2 - y1) * bias_fractional;
  81. if (flip_result) {
  82. y = 1.0f - y;
  83. }
  84. return y;
  85. }
  86. // Pre-computed beta(3, 3) with a fatter tail.
  87. inline float FastBetaDistributionSample(float uniform) {
  88. return stmlib::Interpolate(dist_icdf_4_3, uniform, kIcdfTableSize);
  89. }
  // Draws samples from a discrete distribution. Used for the quantizer.
  //
  // Example, for a distribution yielding:
  // * 1 with probability 0.2
  // * 20 with probability 0.7
  // * 666 with probability 0.1
  //
  // DiscreteDistribution<3> d;
  // d.Init();
  // d.AddToken(1, 0.2f);
  // d.AddToken(20, 0.7f);
  // d.AddToken(666, 0.1f);
  // d.NoMoreTokens();
  // DiscreteDistribution<3>::Result r = d.Sample(u);
  // cout << r.token_id;
  //
  // Weights do not have to add to 1.0f - the class handles normalization.
  // The template argument is the maximum number of tokens; tokens added once
  // the distribution is full are ignored.
  //
  template<size_t size>
  class DiscreteDistribution {
   public:
    DiscreteDistribution() { }
    ~DiscreteDistribution() { }

    // Resets the distribution to the empty state. The constructor does not
    // initialize any member, so this must be called before anything else.
    void Init() {
      sum_ = 0.0f;
      num_tokens_ = 1;
      cdf_[0] = 0.0f;
      token_ids_[0] = 0;
    }

    // Appends a token with the given (unnormalized) weight. The call is a
    // no-op when the weight is not strictly positive, or when `size` tokens
    // have already been stored.
    void AddToken(int token_id, float weight) {
      // Negated comparison so that a NaN weight is rejected as well: it
      // would otherwise turn sum_ and every later cdf entry into NaN.
      if (!(weight > 0.0f)) {
        return;
      }
      // Tokens live in slots 1..size. The slot after the last token must
      // stay free for the sentinel written by NoMoreTokens().
      if (num_tokens_ > static_cast<int>(size)) {
        return;
      }
      sum_ += weight;
      token_ids_[num_tokens_] = token_id;
      cdf_[num_tokens_] = sum_;
      ++num_tokens_;
    }

    // Terminates the token list with a sentinel entry: it repeats the last
    // token id and has a cdf value above the total weight, so that sampling
    // with u = 1.0 still lands on a valid entry.
    void NoMoreTokens() {
      token_ids_[num_tokens_] = token_ids_[num_tokens_ - 1];
      cdf_[num_tokens_] = sum_ + 1.0f;
    }

    struct Result {
      int token_id;    // Id of the token that was drawn.
      float fraction;  // Position of u inside the token's interval.
      float start;     // Normalized lower bound of the token's interval.
      float width;     // Normalized width of the token's interval.
    };

    // Maps u to the token whose cumulative weight interval contains
    // u * (sum of weights).
    // NOTE(review): u is presumably a uniform sample in [0, 1] - confirm
    // against callers. Values above that range are attributed to the
    // sentinel entry instead of reading past it.
    inline Result Sample(float u) const {
      Result r;
      if (!(sum_ > 0.0f)) {
        // No token has been added: normalizing would divide by zero and
        // return inf / NaN. Report the placeholder token written by Init()
        // with a finite interval instead.
        r.token_id = token_ids_[0];
        r.fraction = 0.0f;
        r.start = 0.0f;
        r.width = 1.0f;
        return r;
      }
      u *= sum_;
      // The search starts at slot 1 because slot 0 only holds the cdf
      // origin (0.0f).
      int n = static_cast<int>(
          std::upper_bound(cdf_ + 1, cdf_ + num_tokens_ + 1, u) - cdf_);
      // u beyond the sentinel (or NaN) makes the search return its end
      // position, one slot past the last initialized entry.
      if (n > num_tokens_) {
        n = num_tokens_;
      }
      float norm = 1.0f / sum_;
      r.token_id = token_ids_[n];
      r.width = (cdf_[n] - cdf_[n - 1]) * norm;
      r.start = (cdf_[n - 1]) * norm;
      r.fraction = (u - cdf_[n - 1]) / (cdf_[n] - cdf_[n - 1]);
      return r;
    }

    float sum_;                // Sum of the weights added so far.
    float cdf_[size + 2];      // Cumulative weights: origin, tokens, sentinel.
    int token_ids_[size + 2];  // Token ids, indexed like cdf_.
    int num_tokens_;           // Index of the next free slot.

   private:
    DISALLOW_COPY_AND_ASSIGN(DiscreteDistribution);
  };
  154. } // namespace marbles
  155. #endif // MARBLES_RANDOM_DISTRIBUTIONS_H_