1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20
21 /**
22 * Sampler for a discrete distribution using an optimised look-up table.
23 *
24 * <ul>
25 * <li>
26 * The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described
27 * in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete
28 * Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11.
29 * </li>
30 * </ul>
31 *
32 * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p>
33 *
34 * <p>Memory requirements depend on the maximum number of possible sample values, {@code n},
35 * and the values for the probabilities. Storage is optimised for {@code n}. The worst case
36 * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for
37 * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for
38 * {@code n <=} 2<sup>30</sup>. Realistic requirements will be in the kB range.</p>
39 *
40 * <p>The sampler supports the following distributions:</p>
41 *
42 * <ul>
43 * <li>Enumerated distribution (probabilities must be provided for each sample)</li>
44 * <li>Poisson distribution up to {@code mean = 1024}</li>
45 * <li>Binomial distribution up to {@code trials = 65535}</li>
46 * </ul>
47 *
 * @see <a href="https://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
49 * 11, Issue 3</a>
50 * @since 1.3
51 */
52 public final class MarsagliaTsangWangDiscreteSampler {
53 /** The value 2<sup>8</sup> as an {@code int}. */
54 private static final int INT_8 = 1 << 8;
55 /** The value 2<sup>16</sup> as an {@code int}. */
56 private static final int INT_16 = 1 << 16;
57 /** The value 2<sup>30</sup> as an {@code int}. */
58 private static final int INT_30 = 1 << 30;
59 /** The value 2<sup>31</sup> as a {@code double}. */
60 private static final double DOUBLE_31 = 1L << 31;
61
62 // =========================================================================
63 // Implementation note:
64 //
65 // This sampler uses prepared look-up tables that are searched using a single
66 // random int variate. The look-up tables contain the sample value. The tables
67 // are constructed using probabilities that sum to 2^30. The original paper
68 // by Marsaglia, et al (2004) describes the use of 5, 3, or 2 look-up tables
69 // indexed using digits of base 2^6, 2^10 or 2^15. Currently only base 64 (2^6)
70 // is supported using 5 look-up tables.
71 //
72 // The implementations use 8, 16 or 32 bit storage tables to support different
73 // distribution sizes with optimal storage. Separate class implementations of
74 // the same algorithm allow array storage to be accessed directly from 1D tables.
75 // This provides a performance gain over using: abstracted storage accessed via
76 // an interface; or a single 2D table.
77 //
78 // To allow the optimal implementation to be chosen the sampler is created
79 // using factory methods. The sampler supports any probability distribution
80 // when provided via an array of probabilities and the Poisson and Binomial
81 // distributions for a restricted set of parameters. The restrictions are
82 // imposed by the requirement to compute the entire probability distribution
83 // from the controlling parameter(s) using a recursive method. Factory
84 // constructors return a SharedStateDiscreteSampler instance. Each distribution
85 // type is contained in an inner class.
86 // =========================================================================
87
88 /**
89 * The base class for Marsaglia-Tsang-Wang samplers.
90 */
91 private abstract static class AbstractMarsagliaTsangWangDiscreteSampler
92 implements SharedStateDiscreteSampler {
93 /** Underlying source of randomness. */
94 protected final UniformRandomProvider rng;
95
96 /** The name of the distribution. */
97 private final String distributionName;
98
99 /**
100 * @param rng Generator of uniformly distributed random numbers.
101 * @param distributionName Distribution name.
102 */
103 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
104 String distributionName) {
105 this.rng = rng;
106 this.distributionName = distributionName;
107 }
108
109 /**
110 * @param rng Generator of uniformly distributed random numbers.
111 * @param source Source to copy.
112 */
113 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
114 AbstractMarsagliaTsangWangDiscreteSampler source) {
115 this.rng = rng;
116 this.distributionName = source.distributionName;
117 }
118
119 /** {@inheritDoc} */
120 @Override
121 public String toString() {
122 return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]";
123 }
124 }
125
126 /**
127 * An implementation for the sample algorithm based on the decomposition of the
128 * index in the range {@code [0,2^30)} into 5 base-64 digits with 8-bit backing storage.
129 */
130 private static final class MarsagliaTsangWangBase64Int8DiscreteSampler
131 extends AbstractMarsagliaTsangWangDiscreteSampler {
132 /** The mask to convert a {@code byte} to an unsigned 8-bit integer. */
133 private static final int MASK = 0xff;
134
135 /** Limit for look-up table 1. */
136 private final int t1;
137 /** Limit for look-up table 2. */
138 private final int t2;
139 /** Limit for look-up table 3. */
140 private final int t3;
141 /** Limit for look-up table 4. */
142 private final int t4;
143
144 /** Look-up table table1. */
145 private final byte[] table1;
146 /** Look-up table table2. */
147 private final byte[] table2;
148 /** Look-up table table3. */
149 private final byte[] table3;
150 /** Look-up table table4. */
151 private final byte[] table4;
152 /** Look-up table table5. */
153 private final byte[] table5;
154
155 /**
156 * @param rng Generator of uniformly distributed random numbers.
157 * @param distributionName Distribution name.
158 * @param prob The probabilities.
159 * @param offset The offset (must be positive).
160 */
161 MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
162 String distributionName,
163 int[] prob,
164 int offset) {
165 super(rng, distributionName);
166
167 // Get table sizes for each base-64 digit
168 int n1 = 0;
169 int n2 = 0;
170 int n3 = 0;
171 int n4 = 0;
172 int n5 = 0;
173 for (final int m : prob) {
174 n1 += getBase64Digit(m, 1);
175 n2 += getBase64Digit(m, 2);
176 n3 += getBase64Digit(m, 3);
177 n4 += getBase64Digit(m, 4);
178 n5 += getBase64Digit(m, 5);
179 }
180
181 table1 = new byte[n1];
182 table2 = new byte[n2];
183 table3 = new byte[n3];
184 table4 = new byte[n4];
185 table5 = new byte[n5];
186
187 // Compute offsets
188 t1 = n1 << 24;
189 t2 = t1 + (n2 << 18);
190 t3 = t2 + (n3 << 12);
191 t4 = t3 + (n4 << 6);
192 n1 = 0;
193 n2 = 0;
194 n3 = 0;
195 n4 = 0;
196 n5 = 0;
197
198 // Fill tables
199 for (int i = 0; i < prob.length; i++) {
200 final int m = prob[i];
201 // Primitive type conversion will extract lower 8 bits
202 final byte k = (byte) (i + offset);
203 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
204 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
205 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
206 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
207 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
208 }
209 }
210
211 /**
212 * @param rng Generator of uniformly distributed random numbers.
213 * @param source Source to copy.
214 */
215 private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
216 MarsagliaTsangWangBase64Int8DiscreteSampler source) {
217 super(rng, source);
218 t1 = source.t1;
219 t2 = source.t2;
220 t3 = source.t3;
221 t4 = source.t4;
222 table1 = source.table1;
223 table2 = source.table2;
224 table3 = source.table3;
225 table4 = source.table4;
226 table5 = source.table5;
227 }
228
229 /**
230 * Fill the table with the value.
231 *
232 * @param table Table.
233 * @param from Lower bound index (inclusive)
234 * @param to Upper bound index (exclusive)
235 * @param value Value.
236 * @return the upper bound index
237 */
238 private static int fill(byte[] table, int from, int to, byte value) {
239 for (int i = from; i < to; i++) {
240 table[i] = value;
241 }
242 return to;
243 }
244
245 @Override
246 public int sample() {
247 final int j = rng.nextInt() >>> 2;
248 if (j < t1) {
249 return table1[j >>> 24] & MASK;
250 }
251 if (j < t2) {
252 return table2[(j - t1) >>> 18] & MASK;
253 }
254 if (j < t3) {
255 return table3[(j - t2) >>> 12] & MASK;
256 }
257 if (j < t4) {
258 return table4[(j - t3) >>> 6] & MASK;
259 }
260 // Note the tables are filled on the assumption that the sum of the probabilities.
261 // is >=2^30. If this is not true then the final table table5 will be smaller by the
262 // difference. So the tables *must* be constructed correctly.
263 return table5[j - t4] & MASK;
264 }
265
266 @Override
267 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
268 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this);
269 }
270 }
271
272 /**
273 * An implementation for the sample algorithm based on the decomposition of the
274 * index in the range {@code [0,2^30)} into 5 base-64 digits with 16-bit backing storage.
275 */
276 private static final class MarsagliaTsangWangBase64Int16DiscreteSampler
277 extends AbstractMarsagliaTsangWangDiscreteSampler {
278 /** The mask to convert a {@code byte} to an unsigned 16-bit integer. */
279 private static final int MASK = 0xffff;
280
281 /** Limit for look-up table 1. */
282 private final int t1;
283 /** Limit for look-up table 2. */
284 private final int t2;
285 /** Limit for look-up table 3. */
286 private final int t3;
287 /** Limit for look-up table 4. */
288 private final int t4;
289
290 /** Look-up table table1. */
291 private final short[] table1;
292 /** Look-up table table2. */
293 private final short[] table2;
294 /** Look-up table table3. */
295 private final short[] table3;
296 /** Look-up table table4. */
297 private final short[] table4;
298 /** Look-up table table5. */
299 private final short[] table5;
300
301 /**
302 * @param rng Generator of uniformly distributed random numbers.
303 * @param distributionName Distribution name.
304 * @param prob The probabilities.
305 * @param offset The offset (must be positive).
306 */
307 MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
308 String distributionName,
309 int[] prob,
310 int offset) {
311 super(rng, distributionName);
312
313 // Get table sizes for each base-64 digit
314 int n1 = 0;
315 int n2 = 0;
316 int n3 = 0;
317 int n4 = 0;
318 int n5 = 0;
319 for (final int m : prob) {
320 n1 += getBase64Digit(m, 1);
321 n2 += getBase64Digit(m, 2);
322 n3 += getBase64Digit(m, 3);
323 n4 += getBase64Digit(m, 4);
324 n5 += getBase64Digit(m, 5);
325 }
326
327 table1 = new short[n1];
328 table2 = new short[n2];
329 table3 = new short[n3];
330 table4 = new short[n4];
331 table5 = new short[n5];
332
333 // Compute offsets
334 t1 = n1 << 24;
335 t2 = t1 + (n2 << 18);
336 t3 = t2 + (n3 << 12);
337 t4 = t3 + (n4 << 6);
338 n1 = 0;
339 n2 = 0;
340 n3 = 0;
341 n4 = 0;
342 n5 = 0;
343
344 // Fill tables
345 for (int i = 0; i < prob.length; i++) {
346 final int m = prob[i];
347 // Primitive type conversion will extract lower 16 bits
348 final short k = (short) (i + offset);
349 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
350 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
351 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
352 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
353 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
354 }
355 }
356
357 /**
358 * @param rng Generator of uniformly distributed random numbers.
359 * @param source Source to copy.
360 */
361 private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
362 MarsagliaTsangWangBase64Int16DiscreteSampler source) {
363 super(rng, source);
364 t1 = source.t1;
365 t2 = source.t2;
366 t3 = source.t3;
367 t4 = source.t4;
368 table1 = source.table1;
369 table2 = source.table2;
370 table3 = source.table3;
371 table4 = source.table4;
372 table5 = source.table5;
373 }
374
375 /**
376 * Fill the table with the value.
377 *
378 * @param table Table.
379 * @param from Lower bound index (inclusive)
380 * @param to Upper bound index (exclusive)
381 * @param value Value.
382 * @return the upper bound index
383 */
384 private static int fill(short[] table, int from, int to, short value) {
385 for (int i = from; i < to; i++) {
386 table[i] = value;
387 }
388 return to;
389 }
390
391 @Override
392 public int sample() {
393 final int j = rng.nextInt() >>> 2;
394 if (j < t1) {
395 return table1[j >>> 24] & MASK;
396 }
397 if (j < t2) {
398 return table2[(j - t1) >>> 18] & MASK;
399 }
400 if (j < t3) {
401 return table3[(j - t2) >>> 12] & MASK;
402 }
403 if (j < t4) {
404 return table4[(j - t3) >>> 6] & MASK;
405 }
406 // Note the tables are filled on the assumption that the sum of the probabilities.
407 // is >=2^30. If this is not true then the final table table5 will be smaller by the
408 // difference. So the tables *must* be constructed correctly.
409 return table5[j - t4] & MASK;
410 }
411
412 @Override
413 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
414 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this);
415 }
416 }
417
418 /**
419 * An implementation for the sample algorithm based on the decomposition of the
420 * index in the range {@code [0,2^30)} into 5 base-64 digits with 32-bit backing storage.
421 */
422 private static final class MarsagliaTsangWangBase64Int32DiscreteSampler
423 extends AbstractMarsagliaTsangWangDiscreteSampler {
424 /** Limit for look-up table 1. */
425 private final int t1;
426 /** Limit for look-up table 2. */
427 private final int t2;
428 /** Limit for look-up table 3. */
429 private final int t3;
430 /** Limit for look-up table 4. */
431 private final int t4;
432
433 /** Look-up table table1. */
434 private final int[] table1;
435 /** Look-up table table2. */
436 private final int[] table2;
437 /** Look-up table table3. */
438 private final int[] table3;
439 /** Look-up table table4. */
440 private final int[] table4;
441 /** Look-up table table5. */
442 private final int[] table5;
443
444 /**
445 * @param rng Generator of uniformly distributed random numbers.
446 * @param distributionName Distribution name.
447 * @param prob The probabilities.
448 * @param offset The offset (must be positive).
449 */
450 MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
451 String distributionName,
452 int[] prob,
453 int offset) {
454 super(rng, distributionName);
455
456 // Get table sizes for each base-64 digit
457 int n1 = 0;
458 int n2 = 0;
459 int n3 = 0;
460 int n4 = 0;
461 int n5 = 0;
462 for (final int m : prob) {
463 n1 += getBase64Digit(m, 1);
464 n2 += getBase64Digit(m, 2);
465 n3 += getBase64Digit(m, 3);
466 n4 += getBase64Digit(m, 4);
467 n5 += getBase64Digit(m, 5);
468 }
469
470 table1 = new int[n1];
471 table2 = new int[n2];
472 table3 = new int[n3];
473 table4 = new int[n4];
474 table5 = new int[n5];
475
476 // Compute offsets
477 t1 = n1 << 24;
478 t2 = t1 + (n2 << 18);
479 t3 = t2 + (n3 << 12);
480 t4 = t3 + (n4 << 6);
481 n1 = 0;
482 n2 = 0;
483 n3 = 0;
484 n4 = 0;
485 n5 = 0;
486
487 // Fill tables
488 for (int i = 0; i < prob.length; i++) {
489 final int m = prob[i];
490 final int k = i + offset;
491 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
492 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
493 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
494 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
495 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
496 }
497 }
498
499 /**
500 * @param rng Generator of uniformly distributed random numbers.
501 * @param source Source to copy.
502 */
503 private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
504 MarsagliaTsangWangBase64Int32DiscreteSampler source) {
505 super(rng, source);
506 t1 = source.t1;
507 t2 = source.t2;
508 t3 = source.t3;
509 t4 = source.t4;
510 table1 = source.table1;
511 table2 = source.table2;
512 table3 = source.table3;
513 table4 = source.table4;
514 table5 = source.table5;
515 }
516
517 /**
518 * Fill the table with the value.
519 *
520 * @param table Table.
521 * @param from Lower bound index (inclusive)
522 * @param to Upper bound index (exclusive)
523 * @param value Value.
524 * @return the upper bound index
525 */
526 private static int fill(int[] table, int from, int to, int value) {
527 for (int i = from; i < to; i++) {
528 table[i] = value;
529 }
530 return to;
531 }
532
533 @Override
534 public int sample() {
535 final int j = rng.nextInt() >>> 2;
536 if (j < t1) {
537 return table1[j >>> 24];
538 }
539 if (j < t2) {
540 return table2[(j - t1) >>> 18];
541 }
542 if (j < t3) {
543 return table3[(j - t2) >>> 12];
544 }
545 if (j < t4) {
546 return table4[(j - t3) >>> 6];
547 }
548 // Note the tables are filled on the assumption that the sum of the probabilities.
549 // is >=2^30. If this is not true then the final table table5 will be smaller by the
550 // difference. So the tables *must* be constructed correctly.
551 return table5[j - t4];
552 }
553
554 @Override
555 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
556 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this);
557 }
558 }
559
560
561
562 /** Class contains only static methods. */
563 private MarsagliaTsangWangDiscreteSampler() {}
564
565 /**
566 * Gets the k<sup>th</sup> base 64 digit of {@code m}.
567 *
568 * @param m the value m.
569 * @param k the digit.
570 * @return the base 64 digit
571 */
572 private static int getBase64Digit(int m, int k) {
573 return (m >>> (30 - 6 * k)) & 63;
574 }
575
576 /**
577 * Convert the probability to an integer in the range [0,2^30]. This is the numerator of
578 * a fraction with assumed denominator 2<sup>30</sup>.
579 *
580 * @param p Probability.
581 * @return the fraction numerator
582 */
583 private static int toUnsignedInt30(double p) {
584 return (int) (p * INT_30 + 0.5);
585 }
586
587 /**
588 * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is
589 * {@code i + offset}.
590 *
591 * <p>The sum of the probabilities must be {@code >=} 2<sup>30</sup>. Only the
592 * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p>
593 *
594 * @param rng Generator of uniformly distributed random numbers.
595 * @param distributionName Distribution name.
596 * @param prob The probabilities.
597 * @param offset The offset (must be positive).
598 * @return Sampler.
599 */
600 private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
601 String distributionName,
602 int[] prob,
603 int offset) {
604 // Note: No argument checks for private method.
605
606 // Choose implementation based on the maximum index
607 final int maxIndex = prob.length + offset - 1;
608 if (maxIndex < INT_8) {
609 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset);
610 }
611 if (maxIndex < INT_16) {
612 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset);
613 }
614 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset);
615 }
616
617 // =========================================================================
618 // The following public classes provide factory methods to construct a sampler for:
619 // - Enumerated probability distribution (from provided double[] probabilities)
620 // - Poisson distribution for mean <= 1024
621 // - Binomial distribution for trials <= 65535
622 // =========================================================================
623
624 /**
625 * Create a sampler for an enumerated distribution of {@code n} values each with an
626 * associated probability.
627 * The samples corresponding to each probability are assumed to be a natural sequence
628 * starting at zero.
629 */
630 public static final class Enumerated {
631 /** The name of the enumerated probability distribution. */
632 private static final String ENUMERATED_NAME = "Enumerated";
633
634 /** Class contains only static methods. */
635 private Enumerated() {}
636
637 /**
638 * Creates a sampler for an enumerated distribution of {@code n} values each with an
639 * associated probability.
640 *
641 * <p>The probabilities will be normalised using their sum. The only requirement
642 * is the sum is positive.</p>
643 *
644 * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Note that
645 * probabilities are adjusted to the nearest 2<sup>-30</sup> due to round-off during
646 * the normalisation conversion. Consequently any probability less than 2<sup>-31</sup>
647 * will not be observed in samples.</p>
648 *
649 * @param rng Generator of uniformly distributed random numbers.
650 * @param probabilities The list of probabilities.
651 * @return Sampler.
652 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
653 * probability is negative, infinite or {@code NaN}, or the sum of all
654 * probabilities is not strictly positive.
655 */
656 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
657 double[] probabilities) {
658 return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0);
659 }
660
661 /**
662 * Normalise the probabilities to integers that sum to 2<sup>30</sup>.
663 *
664 * @param probabilities The list of probabilities.
665 * @return the normalised probabilities.
666 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
667 * probability is negative, infinite or {@code NaN}, or the sum of all
668 * probabilities is not strictly positive.
669 */
670 private static int[] normaliseProbabilities(double[] probabilities) {
671 final double sumProb = InternalUtils.validateProbabilities(probabilities);
672
673 // Compute the normalisation: 2^30 / sum
674 final double normalisation = INT_30 / sumProb;
675 final int[] prob = new int[probabilities.length];
676 int sum = 0;
677 int max = 0;
678 int mode = 0;
679 for (int i = 0; i < prob.length; i++) {
680 // Add 0.5 for rounding
681 final int p = (int) (probabilities[i] * normalisation + 0.5);
682 sum += p;
683 // Find the mode (maximum probability)
684 if (max < p) {
685 max = p;
686 mode = i;
687 }
688 prob[i] = p;
689 }
690
691 // The sum must be >= 2^30.
692 // Here just compensate the difference onto the highest probability.
693 prob[mode] += INT_30 - sum;
694
695 return prob;
696 }
697 }
698
/**
 * Create a sampler for the Poisson distribution.
 */
public static final class Poisson {
    /** The name of the Poisson distribution. */
    private static final String POISSON_NAME = "Poisson";

    /**
     * Upper bound on the mean for the Poisson distribution.
     *
     * <p>The original source code provided in Marsaglia, et al (2004) has no explicit
     * limit but the code fails at mean {@code >= 1941} as the transform to compute p(x=mode)
     * produces infinity. Use a conservative limit of 1024.</p>
     */
    private static final double MAX_MEAN = 1024;
    /**
     * The threshold for the mean of the Poisson distribution to switch the method used
     * to compute the probabilities. This is taken from the example software provided by
     * Marsaglia, et al (2004).
     */
    private static final double MEAN_THRESHOLD = 21.4;

    /** Class contains only static methods. */
    private Poisson() {}

    /**
     * Creates a sampler for the Poisson distribution.
     *
     * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
     *
     * <p>Storage requirements depend on the tabulated probability values. Example storage
     * requirements are listed below.</p>
     *
     * <pre>
     * mean      table size     kB
     * 0.25      882            0.88
     * 0.5       1135           1.14
     * 1         1200           1.20
     * 2         1451           1.45
     * 4         1955           1.96
     * 8         2961           2.96
     * 16        4410           4.41
     * 32        6115           6.11
     * 64        8499           8.50
     * 128       11528          11.53
     * 256       15935          31.87
     * 512       20912          41.82
     * 1024      30614          61.23
     * </pre>
     *
     * <p>Note: Storage changes to 2 bytes per index between {@code mean=128} and {@code mean=256}.</p>
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param mean Mean.
     * @return Sampler.
     * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
     */
    public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
                                                double mean) {
        validatePoissonDistributionParameters(mean);

        // Create the distribution either from X=0 or from X=mode when the mean is high.
        return mean < MEAN_THRESHOLD ?
            createPoissonDistributionFromX0(rng, mean) :
            createPoissonDistributionFromXMode(rng, mean);
    }

    /**
     * Validate the Poisson distribution parameters.
     *
     * <p>The positivity check is delegated to {@code InternalUtils.requireStrictlyPositive}.
     * NOTE(review): handling of {@code NaN} depends on that helper; the upper-bound
     * comparison below is false for {@code NaN}.</p>
     *
     * @param mean Mean.
     * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
     */
    private static void validatePoissonDistributionParameters(double mean) {
        InternalUtils.requireStrictlyPositive(mean, "mean");
        if (mean > MAX_MEAN) {
            throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
        }
    }

    /**
     * Creates the Poisson distribution by computing probabilities recursively from {@code X=0}.
     *
     * <p>Used for small means, where {@code p(0) = exp(-mean)} is large enough to start
     * the recursion.</p>
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param mean Mean.
     * @return Sampler.
     */
    private static SharedStateDiscreteSampler createPoissonDistributionFromX0(
            UniformRandomProvider rng, double mean) {
        final double p0 = Math.exp(-mean);

        // Recursive update of Poisson probability until the value is too small
        // p(x + 1) = p(x) * mean / (x + 1)
        // "Too small" means p(x) < 2^-31, which would round to 0 as a 30-bit integer.
        // On exit p holds p(i - 1), the first value below the threshold.
        double p = p0;
        int i = 1;
        while (p * DOUBLE_31 >= 1) {
            p *= mean / i++;
        }

        // Probabilities are 30-bit integers, assumed denominator 2^30.
        // Tabulate x = 0 .. i - 2: every value that reached the threshold.
        final int size = i - 1;
        final int[] prob = new int[size];

        // Repeat the same recursion to fill the table
        p = p0;
        prob[0] = toUnsignedInt30(p);
        // The sum must exceed 2^30. In edge cases this is false due to round-off.
        int sum = prob[0];
        for (i = 1; i < prob.length; i++) {
            p *= mean / i;
            prob[i] = toUnsignedInt30(p);
            sum += prob[i];
        }

        // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)).
        // If the sum is > 2^30 nothing is removed; the sampler ignores the excess tail.
        prob[(int) mean] += Math.max(0, INT_30 - sum);

        // Note: offset = 0
        return createSampler(rng, POISSON_NAME, prob, 0);
    }

    /**
     * Creates the Poisson distribution by computing probabilities recursively upward and downward
     * from {@code X=mode}, the location of the largest p-value.
     *
     * <p>Used for large means, where {@code p(0)} is too small to start the recursion from
     * {@code X=0}.</p>
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param mean Mean.
     * @return Sampler.
     */
    private static SharedStateDiscreteSampler createPoissonDistributionFromXMode(
            UniformRandomProvider rng, double mean) {
        // If mean >= 21.4, generate from largest p-value up, then largest down.
        // The largest p-value will be at the mode (floor(mean)).

        // Find p(x=mode)
        final int mode = (int) mean;
        // p(mode) = mean^mode * exp(-mean) / mode! is computed as prod_{i=1..mode} c / i
        // with c = mean * exp(-mean / mode). This spreads the exp(-mean) factor over the
        // product rather than evaluating exp(-mean) directly, which underflows to 0 in
        // double precision for mean above ~745.
        // This transform is stable until mean >= 1941 where p will result in Infinity
        // before the divisor i is large enough to start reducing the product (i.e. i > c).
        final double c = mean * Math.exp(-mean / mode);
        double p = 1.0;
        for (int i = 1; i <= mode; i++) {
            p *= c / i;
        }
        final double pMode = p;

        // Find the upper limit using recursive computation of the p-value.
        // Note this will exit when i overflows to negative so no check on the range
        // On exit p holds p(i - 1), the first value above the mode below 2^-31,
        // so the last tabulated value is i - 2.
        int i = mode + 1;
        while (p * DOUBLE_31 >= 1) {
            p *= mean / i++;
        }
        final int last = i - 2;

        // Find the lower limit using recursive computation of the p-value.
        // p(i) = p(i + 1) * (i + 1) / mean. j is the largest x below the mode with
        // p(x) < 2^-31, or -1 if every value down to 0 reaches the threshold.
        p = pMode;
        int j = -1;
        for (i = mode - 1; i >= 0; i--) {
            p *= (i + 1) / mean;
            if (p * DOUBLE_31 < 1) {
                j = i;
                break;
            }
        }

        // Probabilities are 30-bit integers, assumed denominator 2^30.
        // This is the minimum sample value: prob[x - offset] = p(x)
        final int offset = j + 1;
        final int size = last - offset + 1;
        final int[] prob = new int[size];

        p = pMode;
        prob[mode - offset] = toUnsignedInt30(p);
        // The sum must exceed 2^30. In edge cases this is false due to round-off.
        int sum = prob[mode - offset];
        // From mode to upper limit
        for (i = mode + 1; i <= last; i++) {
            p *= mean / i;
            prob[i - offset] = toUnsignedInt30(p);
            sum += prob[i - offset];
        }
        // From mode to lower limit
        p = pMode;
        for (i = mode - 1; i >= offset; i--) {
            p *= (i + 1) / mean;
            prob[i - offset] = toUnsignedInt30(p);
            sum += prob[i - offset];
        }

        // If the sum is < 2^30 add the remaining sum to the mode.
        // If above 2^30 then the effect is truncation of the long tail of the distribution.
        prob[mode - offset] += Math.max(0, INT_30 - sum);

        return createSampler(rng, POISSON_NAME, prob, offset);
    }
}
894
895 /**
896 * Create a sampler for the Binomial distribution.
897 */
898 public static final class Binomial {
899 /** The name of the Binomial distribution. */
900 private static final String BINOMIAL_NAME = "Binomial";
901
902 /**
903 * Return a fixed result for the Binomial distribution. This is a special class to handle
904 * an edge case of probability of success equal to 0 or 1.
905 */
906 private static final class MarsagliaTsangWangFixedResultBinomialSampler
907 extends AbstractMarsagliaTsangWangDiscreteSampler {
908 /** The result. */
909 private final int result;
910
911 /**
912 * @param result Result.
913 */
914 MarsagliaTsangWangFixedResultBinomialSampler(int result) {
915 super(null, BINOMIAL_NAME);
916 this.result = result;
917 }
918
919 @Override
920 public int sample() {
921 return result;
922 }
923
924 @Override
925 public String toString() {
926 return BINOMIAL_NAME + " deviate";
927 }
928
929 @Override
930 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
931 // No shared state
932 return this;
933 }
934 }
935
936 /**
937 * Return an inversion result for the Binomial distribution. This assumes the
938 * following:
939 *
940 * <pre>
941 * Binomial(n, p) = 1 - Binomial(n, 1 - p)
942 * </pre>
943 */
944 private static final class MarsagliaTsangWangInversionBinomialSampler
945 extends AbstractMarsagliaTsangWangDiscreteSampler {
946 /** The number of trials. */
947 private final int trials;
948 /** The Binomial distribution sampler. */
949 private final SharedStateDiscreteSampler sampler;
950
951 /**
952 * @param trials Number of trials.
953 * @param sampler Binomial distribution sampler.
954 */
955 MarsagliaTsangWangInversionBinomialSampler(int trials,
956 SharedStateDiscreteSampler sampler) {
957 super(null, BINOMIAL_NAME);
958 this.trials = trials;
959 this.sampler = sampler;
960 }
961
962 @Override
963 public int sample() {
964 return trials - sampler.sample();
965 }
966
967 @Override
968 public String toString() {
969 return sampler.toString();
970 }
971
972 @Override
973 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
974 return new MarsagliaTsangWangInversionBinomialSampler(this.trials,
975 this.sampler.withUniformRandomProvider(rng));
976 }
977 }
978
979 /** Class contains only static methods. */
980 private Binomial() {}
981
982 /**
983 * Creates a sampler for the Binomial distribution.
984 *
985 * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
986 *
987 * <p>Storage requirements depend on the tabulated probability values. Example storage
988 * requirements are listed below (in kB).</p>
989 *
990 * <pre>
991 * p
992 * trials 0.5 0.1 0.01 0.001
993 * 4 0.06 0.63 0.44 0.44
994 * 16 0.69 1.14 0.76 0.44
995 * 64 4.73 2.40 1.14 0.51
996 * 256 8.63 5.17 1.89 0.82
997 * 1024 31.12 9.45 3.34 0.89
998 * </pre>
999 *
1000 * <p>The method requires that the Binomial distribution probability at {@code x=0} can be computed.
1001 * This will fail when {@code (1 - p)^trials == 0} which requires {@code trials} to be large
1002 * and/or {@code p} to be small. In this case an exception is raised.</p>
1003 *
1004 * @param rng Generator of uniformly distributed random numbers.
1005 * @param trials Number of trials.
1006 * @param probabilityOfSuccess Probability of success (p).
1007 * @return Sampler.
1008 * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16},
1009 * {@code p} is not in the range {@code [0-1]}, or the probability distribution cannot
1010 * be computed.
1011 */
1012 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
1013 int trials,
1014 double probabilityOfSuccess) {
1015 validateBinomialDistributionParameters(trials, probabilityOfSuccess);
1016
1017 // Handle edge cases
1018 if (probabilityOfSuccess == 0) {
1019 return new MarsagliaTsangWangFixedResultBinomialSampler(0);
1020 }
1021 if (probabilityOfSuccess == 1) {
1022 return new MarsagliaTsangWangFixedResultBinomialSampler(trials);
1023 }
1024
1025 // Check the supported size.
1026 if (trials >= INT_16) {
1027 throw new IllegalArgumentException("Unsupported number of trials: " + trials);
1028 }
1029
1030 return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess);
1031 }
1032
1033 /**
1034 * Validate the Binomial distribution parameters.
1035 *
1036 * @param trials Number of trials.
1037 * @param probabilityOfSuccess Probability of success (p).
1038 * @throws IllegalArgumentException if {@code trials < 0} or
1039 * {@code p} is not in the range {@code [0-1]}
1040 */
1041 private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) {
1042 if (trials < 0) {
1043 throw new IllegalArgumentException("Trials is not positive: " + trials);
1044 }
1045 InternalUtils.requireRangeClosed(0, 1, probabilityOfSuccess, "probability of success");
1046 }
1047
1048 /**
1049 * Creates the Binomial distribution sampler.
1050 *
1051 * <p>This assumes the parameters for the distribution are valid. The method
1052 * will only fail if the initial probability for {@code X=0} is zero.</p>
1053 *
1054 * @param rng Generator of uniformly distributed random numbers.
1055 * @param trials Number of trials.
1056 * @param probabilityOfSuccess Probability of success (p).
1057 * @return Sampler.
1058 * @throws IllegalArgumentException if the probability distribution cannot be
1059 * computed.
1060 */
1061 private static SharedStateDiscreteSampler createBinomialDistributionSampler(
1062 UniformRandomProvider rng, int trials, double probabilityOfSuccess) {
1063
1064 // The maximum supported value for Math.exp is approximately -744.
1065 // This occurs when trials is large and p is close to 1.
1066 // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j
1067 final boolean useInversion = probabilityOfSuccess > 0.5;
1068 final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess;
1069
1070 // Check if the distribution can be computed
1071 final double p0 = Math.exp(trials * Math.log(1 - p));
1072 if (p0 < Double.MIN_VALUE) {
1073 throw new IllegalArgumentException("Unable to compute distribution");
1074 }
1075
1076 // First find size of probability array
1077 double t = p0;
1078 final double h = p / (1 - p);
1079 // Find first probability above the threshold of 2^-31
1080 int begin = 0;
1081 if (t * DOUBLE_31 < 1) {
1082 // Somewhere after p(0)
1083 // Note:
1084 // If this loop is entered p(0) is < 2^-31.
1085 // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either
1086 // p=0.5 or trials=2^16-1 and does not fail to find the beginning.
1087 for (int i = 1; i <= trials; i++) {
1088 t *= (trials + 1 - i) * h / i;
1089 if (t * DOUBLE_31 >= 1) {
1090 begin = i;
1091 break;
1092 }
1093 }
1094 }
1095 // Find last probability
1096 int end = trials;
1097 for (int i = begin + 1; i <= trials; i++) {
1098 t *= (trials + 1 - i) * h / i;
1099 if (t * DOUBLE_31 < 1) {
1100 end = i - 1;
1101 break;
1102 }
1103 }
1104
1105 return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion,
1106 p0, begin, end);
1107 }
1108
1109 /**
1110 * Creates the Binomial distribution sampler using only the probability values for {@code X}
1111 * between the begin and the end (inclusive).
1112 *
1113 * @param rng Generator of uniformly distributed random numbers.
1114 * @param trials Number of trials.
1115 * @param p Probability of success (p).
1116 * @param useInversion Set to {@code true} if the probability was inverted.
1117 * @param p0 Probability at {@code X=0}
1118 * @param begin Begin value {@code X} for the distribution.
1119 * @param end End value {@code X} for the distribution.
1120 * @return Sampler.
1121 */
1122 private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange(
1123 UniformRandomProvider rng, int trials, double p,
1124 boolean useInversion, double p0, int begin, int end) {
1125
1126 // Assign probability values as 30-bit integers
1127 final int size = end - begin + 1;
1128 final int[] prob = new int[size];
1129 double t = p0;
1130 final double h = p / (1 - p);
1131 for (int i = 1; i <= begin; i++) {
1132 t *= (trials + 1 - i) * h / i;
1133 }
1134 int sum = toUnsignedInt30(t);
1135 prob[0] = sum;
1136 for (int i = begin + 1; i <= end; i++) {
1137 t *= (trials + 1 - i) * h / i;
1138 prob[i - begin] = toUnsignedInt30(t);
1139 sum += prob[i - begin];
1140 }
1141
1142 // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))).
1143 // If above 2^30 then the effect is truncation of the long tail of the distribution.
1144 final int mode = (int) ((trials + 1) * p) - begin;
1145 prob[mode] += Math.max(0, INT_30 - sum);
1146
1147 final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin);
1148
1149 // Check if an inversion was made
1150 return useInversion ?
1151 new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) :
1152 sampler;
1153 }
1154 }
1155 }