// NormalInt.java
package org.loudouncodes.randkit.discrete;
import java.util.random.RandomGenerator;
import org.loudouncodes.randkit.api.DiscreteDistribution;
import org.loudouncodes.randkit.api.DistributionSupport;
import org.loudouncodes.randkit.util.Randoms;
/**
 * A <em>discrete normal</em> (rounded-normal) distribution on the integers.
 *
 * <p>This distribution is constructed by sampling a continuous normal {@code X ~ Normal(mean,
 * sigma^2)} and returning the integer {@code Y = round(X)} (nearest-even rounding via {@link
 * Math#rint(double)}). It therefore places probability mass on all integers (unless truncated).
 *
 * <h2>Probability functions</h2>
 *
 * <p>For the <strong>untruncated</strong> rounded-normal, the probability mass function (PMF) and
 * cumulative distribution function (CDF) are:
 *
 * <pre>{@code
 * pmf(k) = Φ((k + 0.5 - mean)/sigma) - Φ((k - 0.5 - mean)/sigma)
 * cdf(k) = Φ((k + 0.5 - mean)/sigma)
 * }</pre>
 *
 * where {@code Φ} is the standard normal CDF.
 *
 * <p>For the <strong>truncated</strong> variant on {@code [lower, upper]} (inclusive), mass is
 * re-normalized to that interval:
 *
 * <pre>{@code
 * Z = Φ((upper + 0.5 - mean)/sigma) - Φ((lower - 0.5 - mean)/sigma)
 * pmf_trunc(k) = pmf(k) / Z for k in [lower, upper], else 0
 * cdf_trunc(k) = (Φ((k + 0.5 - mean)/sigma) - Φ((lower - 0.5 - mean)/sigma)) / Z
 * }</pre>
 *
 * <p>If {@code Z} evaluates to zero (or less) numerically, the instance degenerates to a point
 * mass at the rounded mean clamped into {@code [lower, upper]}.
 *
 * <h2>Sampling</h2>
 *
 * <p>Sampling uses the Marsaglia polar method (the rejection form of Box–Muller) to draw a standard
 * normal, then scales/shifts and applies nearest-even rounding. When truncated, out-of-range draws
 * are rejected and re-sampled (efficient unless the window is deep in the tails).
 *
 * <h2>Numerics</h2>
 *
 * <ul>
 *   <li>{@code Φ} is computed via an {@code erf} approximation (Abramowitz &amp; Stegun 7.1.26),
 *       with max absolute error ≈ 1.5e−7.
 *   <li>Untruncated mean/variance are computed by summing the PMF over a wide {@code ±8σ} window
 *       (extended by up to 32 integers per side while the captured mass is below {@code 1 -
 *       1e-12}).
 *   <li>Truncated mean/variance are computed by summing the re-normalized PMF over the part of
 *       {@code [lower, upper]} that lies within {@code ±12σ} of {@code mean}; outside that range
 *       the {@code erf} approximation has saturated and the PMF evaluates to exactly zero.
 * </ul>
 *
 * <h2>Determinism &amp; threading</h2>
 *
 * <p>Given the same seed, parameters, and JDK RNG algorithm, sequences are repeatable. Instances
 * are not synchronized (sampling mutates a cached spare variate); prefer one instance per thread or
 * supply thread-local RNGs.
 *
 * <h2>Examples</h2>
 *
 * <pre>{@code
 * // Untruncated discrete normal centered near 2, σ = 1.5
 * var d1 = new NormalInt(2.0, 1.5);
 * int x = d1.sample();
 *
 * // Seeded and truncated to [-3, +3]
 * var d2 = new NormalInt(1234L, 0.0, 2.0, -3, 3);
 * double p0 = d2.pmf(0); // probability at 0
 * double F2 = d2.cdf(2); // P(Y <= 2)
 * }</pre>
 *
 * @since 0.1.0
 */
public final class NormalInt implements DiscreteDistribution {

  /** {@code 1/sqrt(2)}; converts a z-score into the argument expected by {@code erf}. */
  private static final double INV_SQRT2 = 1.0 / Math.sqrt(2.0);

  /** Initial half-width (in units of {@code σ}) of the untruncated moment summation window. */
  private static final double UNTRUNCATED_WINDOW_SIGMAS = 8.0;

  /** Maximum number of one-integer-per-side extensions of the untruncated moment window. */
  private static final int MAX_WINDOW_EXTENSIONS = 32;

  /** The untruncated moment window stops growing once the captured mass reaches this value. */
  private static final double TARGET_MASS = 1.0 - 1e-12;

  /**
   * Half-width (in units of {@code σ}) beyond which truncated moment terms are skipped. The {@code
   * erf} approximation used here rounds to exactly {@code ±1} at roughly {@code 8.4σ}, so every
   * term outside {@code ±12σ} is exactly zero and skipping it does not change the sums.
   */
  private static final double TRUNCATED_REACH_SIGMAS = 12.0;

  /** RNG supplying uniform bits; provided or created via {@link Randoms}. */
  private final RandomGenerator rng;

  /** Mean {@code μ} of the underlying continuous normal. */
  private final double meanParam;

  /** Standard deviation {@code σ} of the underlying continuous normal (must be {@code > 0}). */
  private final double sigma;

  // Truncation metadata (meaningful when 'truncated' is true).
  private final boolean truncated;
  private final int lower;
  private final int upper;
  private final double normZ; // normalization constant Z for the truncated mass
  private final boolean degenerate; // true when the truncation window has (numerically) no mass
  private final int degenerateValue; // the single outcome returned when 'degenerate' is true

  // Moments computed once at construction and reported by mean()/variance().
  private final double meanCached;
  private final double varianceCached;

  // Marsaglia polar method yields two variates per round; the second is kept for the next call.
  private boolean haveSpare = false;
  private double spare = 0.0;

  // ---------- Constructors (untruncated) ----------

  /**
   * Creates an untruncated discrete normal using the library's default RNG.
   *
   * @param mean the mean {@code μ} of the underlying continuous normal (finite)
   * @param sigma the standard deviation {@code σ} of the underlying normal; must be {@code > 0}
   * @throws IllegalArgumentException if {@code mean} is not finite or {@code sigma <= 0} or not
   *     finite
   */
  public NormalInt(double mean, double sigma) {
    this(Randoms.defaultGenerator(), mean, sigma);
  }

  /**
   * Creates an untruncated discrete normal with a deterministic RNG built from {@code seed}.
   *
   * @param seed RNG seed for reproducible sampling
   * @param mean the mean {@code μ} of the underlying continuous normal (finite)
   * @param sigma the standard deviation {@code σ} of the underlying normal; must be {@code > 0}
   * @throws IllegalArgumentException if {@code mean} is not finite or {@code sigma <= 0} or not
   *     finite
   */
  public NormalInt(long seed, double mean, double sigma) {
    this(Randoms.seeded(seed), mean, sigma);
  }

  /**
   * Creates an untruncated discrete normal with a caller-supplied RNG.
   *
   * <p>Mean and variance are computed eagerly, so construction cost grows linearly with {@code
   * sigma}.
   *
   * @param rng the random generator (must not be {@code null})
   * @param mean the mean {@code μ} of the underlying continuous normal (finite)
   * @param sigma the standard deviation {@code σ} of the underlying normal; must be {@code > 0}
   * @throws NullPointerException if {@code rng} is {@code null}
   * @throws IllegalArgumentException if {@code mean} is not finite or {@code sigma <= 0} or not
   *     finite
   */
  public NormalInt(RandomGenerator rng, double mean, double sigma) {
    validateParams(mean, sigma);
    if (rng == null) {
      throw new NullPointerException("rng must not be null");
    }
    this.rng = rng;
    this.meanParam = mean;
    this.sigma = sigma;
    this.truncated = false;
    this.lower = Integer.MIN_VALUE;
    this.upper = Integer.MAX_VALUE;
    this.normZ = 1.0;
    this.degenerate = false;
    this.degenerateValue = (int) Math.rint(mean);
    double[] moments = computeUntruncatedMoments();
    this.meanCached = moments[0];
    this.varianceCached = moments[1];
  }

  // ---------- Constructors (truncated) ----------

  /**
   * Creates a truncated discrete normal on the closed interval {@code [lower, upper]} using the
   * library's default RNG. Mass is re-normalized to the interval.
   *
   * @param mean the mean {@code μ} of the underlying continuous normal (finite)
   * @param sigma the standard deviation {@code σ} of the underlying normal; must be {@code > 0}
   * @param lower inclusive lower bound for the integer outcomes
   * @param upper inclusive upper bound for the integer outcomes; must be {@code >= lower}
   * @throws IllegalArgumentException if parameters are invalid or {@code lower > upper}
   */
  public NormalInt(double mean, double sigma, int lower, int upper) {
    this(Randoms.defaultGenerator(), mean, sigma, lower, upper);
  }

  /**
   * Creates a truncated discrete normal on {@code [lower, upper]} with a deterministic RNG.
   *
   * @param seed RNG seed for reproducible sampling
   * @param mean the mean {@code μ} of the underlying continuous normal (finite)
   * @param sigma the standard deviation {@code σ} of the underlying normal; must be {@code > 0}
   * @param lower inclusive lower bound for the integer outcomes
   * @param upper inclusive upper bound for the integer outcomes; must be {@code >= lower}
   * @throws IllegalArgumentException if parameters are invalid or {@code lower > upper}
   */
  public NormalInt(long seed, double mean, double sigma, int lower, int upper) {
    this(Randoms.seeded(seed), mean, sigma, lower, upper);
  }

  /**
   * Creates a truncated discrete normal on {@code [lower, upper]} with a caller-supplied RNG.
   *
   * <p>If the window captures no mass numerically ({@code Z <= 0}), the instance becomes a point
   * mass at the rounded mean clamped into {@code [lower, upper]}.
   *
   * @param rng the random generator (must not be {@code null})
   * @param mean the mean {@code μ} of the underlying continuous normal (finite)
   * @param sigma the standard deviation {@code σ} of the underlying normal; must be {@code > 0}
   * @param lower inclusive lower bound for the integer outcomes
   * @param upper inclusive upper bound for the integer outcomes; must be {@code >= lower}
   * @throws NullPointerException if {@code rng} is {@code null}
   * @throws IllegalArgumentException if parameters are invalid or {@code lower > upper}
   */
  public NormalInt(RandomGenerator rng, double mean, double sigma, int lower, int upper) {
    validateParams(mean, sigma);
    if (rng == null) {
      throw new NullPointerException("rng must not be null");
    }
    if (lower > upper) {
      throw new IllegalArgumentException(
          "lower must be <= upper (got " + lower + " > " + upper + ")");
    }
    this.rng = rng;
    this.meanParam = mean;
    this.sigma = sigma;
    this.truncated = true;
    this.lower = lower;
    this.upper = upper;

    // Normalization mass over the truncated support:
    //   Z = P(Y in [lower, upper]) = Φ(upper + 0.5) - Φ(lower - 0.5)
    // (int ± 0.5 is evaluated in double arithmetic, so the bounds cannot overflow here).
    double z = normalCdf(upper + 0.5) - normalCdf(lower - 0.5);
    this.normZ = z;

    int roundedMean = (int) Math.rint(mean);
    if (!(z > 0.0)) {
      // Window captures (numerically) no mass: become a point mass at the rounded mean,
      // clamped to [lower, upper].
      int k = Math.min(upper, Math.max(lower, roundedMean));
      this.degenerate = true;
      this.degenerateValue = k;
      this.meanCached = k;
      this.varianceCached = 0.0;
    } else {
      this.degenerate = false;
      this.degenerateValue = roundedMean;
      double[] moments = computeTruncatedMoments(lower, upper, z);
      this.meanCached = moments[0];
      this.varianceCached = moments[1];
    }
  }

  // ---------- API (DiscreteDistribution) ----------

  /**
   * Draws a single sample {@code Y} from this distribution.
   *
   * <p>For untruncated distributions, this value can be any integer (with quickly decaying tails);
   * draws whose rounded value falls outside the {@code int} range are pinned to {@code
   * Integer.MIN_VALUE} / {@code Integer.MAX_VALUE} by the narrowing cast. For truncated
   * distributions, the result is guaranteed to lie in {@code [lower, upper]}; out-of-range draws
   * are rejected and re-sampled, which can take many iterations when the window holds very little
   * mass.
   *
   * @return a random integer variate
   */
  @Override
  public int sample() {
    if (degenerate) {
      return degenerateValue;
    }
    while (true) {
      double draw = meanParam + sigma * nextGaussian(); // Normal(mean, sigma^2)
      int y = (int) Math.rint(draw); // nearest-even rounding
      if (!truncated || (y >= lower && y <= upper)) {
        return y;
      }
      // Out of the truncation window: reject and resample.
    }
  }

  /**
   * Returns the probability mass at {@code k}, i.e., {@code P(Y = k)}.
   *
   * @param k integer at which to evaluate the PMF
   * @return the probability mass at {@code k} (zero if truncated and {@code k} is outside the
   *     bounds)
   */
  @Override
  public double pmf(int k) {
    if (degenerate) {
      return (k == degenerateValue) ? 1.0 : 0.0;
    }
    if (!truncated) {
      return pmfUntruncated(k);
    }
    if (k < lower || k > upper) {
      return 0.0;
    }
    return pmfUntruncated(k) / normZ;
  }

  /**
   * Returns the cumulative distribution function at {@code k}, i.e., {@code P(Y ≤ k)}.
   *
   * @param k integer at which to evaluate the CDF
   * @return the cumulative probability at {@code k}; for truncated instances this is {@code 0}
   *     below {@code lower}, {@code 1} at or above {@code upper}, and clamped to {@code [0, 1]} in
   *     between
   */
  @Override
  public double cdf(int k) {
    if (degenerate) {
      return (k < degenerateValue) ? 0.0 : 1.0;
    }
    if (!truncated) {
      return cdfUntruncated(k);
    }
    // Compare against 'lower' directly: 'lower - 1' would wrap when lower == Integer.MIN_VALUE.
    if (k < lower) {
      return 0.0;
    }
    if (k >= upper) {
      return 1.0;
    }
    // Truncated CDF: (Φ(k + 0.5) - Φ(lower - 0.5)) / Z
    double ratio = (normalCdf(k + 0.5) - normalCdf(lower - 0.5)) / normZ;
    // Guard against round-off in the erf approximation pushing the ratio outside [0, 1].
    return Math.min(1.0, Math.max(0.0, ratio));
  }

  /**
   * Returns the distribution mean (expected value).
   *
   * <p>For the rounded-normal, this is close to {@code mean}, with a tiny quantization effect. For
   * the truncated variant, this is the re-normalized mean over the interval.
   *
   * @return {@code E[Y]}, computed once at construction
   */
  @Override
  public double mean() {
    return meanCached;
  }

  /**
   * Returns the distribution variance.
   *
   * <p>For the rounded-normal, this is close to {@code sigma^2} plus a small quantization effect.
   * For the truncated variant, this is the re-normalized variance over the interval.
   *
   * @return {@code Var[Y]}, computed once at construction
   */
  @Override
  public double variance() {
    return varianceCached;
  }

  /**
   * Reports the mathematical support (domain) of this distribution.
   *
   * <p>For untruncated instances the support is the unbounded integer line. For truncated instances
   * it is the closed interval {@code [lower, upper]}.
   *
   * @return a {@link DistributionSupport} describing the support (discrete)
   */
  public DistributionSupport support() {
    if (truncated) {
      return DistributionSupport.discrete(lower, true, upper, true);
    }
    return DistributionSupport.discrete(
        Double.NEGATIVE_INFINITY, false,
        Double.POSITIVE_INFINITY, false);
  }

  // ---------- Internal helpers ----------

  /**
   * Validates constructor parameters.
   *
   * @param mean the mean parameter (finite)
   * @param sigma the standard deviation parameter (finite and {@code > 0})
   * @throws IllegalArgumentException if parameters are invalid
   */
  private static void validateParams(double mean, double sigma) {
    if (!Double.isFinite(mean)) {
      throw new IllegalArgumentException("mean must be finite");
    }
    if (!Double.isFinite(sigma) || !(sigma > 0.0)) {
      throw new IllegalArgumentException("sigma must be finite and > 0");
    }
  }

  /**
   * PMF for the <em>untruncated</em> rounded-normal at integer {@code k}.
   *
   * @param k integer support point
   * @return {@code P(Y = k)} before any truncation is applied; never negative
   */
  private double pmfUntruncated(int k) {
    double p = normalCdf(k + 0.5) - normalCdf(k - 0.5);
    return (p <= 0.0) ? 0.0 : p; // guard against tiny negatives from round-off
  }

  /**
   * CDF for the <em>untruncated</em> rounded-normal at integer {@code k}.
   *
   * @param k integer support point
   * @return {@code P(Y ≤ k)} before any truncation is applied
   */
  private double cdfUntruncated(int k) {
    return normalCdf(k + 0.5);
  }

  /**
   * CDF of the underlying continuous {@code Normal(mean, sigma^2)} at {@code x}.
   *
   * <p>The argument is on the original (unstandardized) scale; it is standardized here with this
   * instance's {@code mean} and {@code sigma} and then passed through {@code Φ}, which is evaluated
   * with the local {@code erf} approximation rather than a library routine.
   *
   * @param x point on the scale of the underlying normal
   * @return {@code Φ((x - mean) / sigma)}
   */
  private double normalCdf(double x) {
    double z = (x - meanParam) / sigma;
    return 0.5 * (1.0 + erfApprox(z * INV_SQRT2));
  }

  /**
   * Approximation to {@code erf(x)} with max absolute error ≈ 1.5e−7.
   *
   * <p>From Abramowitz &amp; Stegun 7.1.26 (Horner form). The formula is defined for {@code x >= 0}
   * and extended to negative arguments through the odd symmetry {@code erf(-x) = -erf(x)}.
   *
   * @param x input
   * @return approximate {@code erf(x)}
   */
  private static double erfApprox(double x) {
    double sign = Math.signum(x);
    double ax = Math.abs(x);
    double t = 1.0 / (1.0 + 0.3275911 * ax);
    double y =
        1.0
            - (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t - 0.284496736) * t
                    + 0.254829592)
                * t
                * Math.exp(-ax * ax);
    return sign * y;
  }

  /**
   * Computes mean and variance for the <em>untruncated</em> rounded-normal by summing a wide window
   * of integer support (starting at {@code ±8σ} and expanding by one integer per side, at most 32
   * times, while the captured mass is below {@code 1 - 1e-12}).
   *
   * <p>The window endpoints are clamped to the {@code int} range by the saturating {@code double ->
   * int} cast, and the loop counter is a {@code long} so a window ending at {@code
   * Integer.MAX_VALUE} still terminates.
   *
   * @return array {@code [mean, variance]}
   */
  private double[] computeUntruncatedMoments() {
    double halfWidth = UNTRUNCATED_WINDOW_SIGMAS * sigma;
    int lo = (int) Math.floor(meanParam - halfWidth);
    int hi = (int) Math.ceil(meanParam + halfWidth);

    double mass = 0.0;
    double m1 = 0.0;
    double m2 = 0.0;
    // long counter: an int counter would wrap past Integer.MAX_VALUE and loop forever.
    for (long k = lo; k <= hi; k++) {
      double p = pmfUntruncated((int) k);
      mass += p;
      m1 += p * k;
      m2 += p * (double) k * (double) k;
    }

    // Widen symmetrically while noticeable tail mass is still missing.
    for (int step = 0; step < MAX_WINDOW_EXTENSIONS && mass < TARGET_MASS; step++) {
      double pLeft = 0.0;
      if (lo > Integer.MIN_VALUE) { // cannot step below the smallest int
        lo--;
        pLeft = pmfUntruncated(lo);
      }
      double pRight = 0.0;
      if (hi < Integer.MAX_VALUE) { // cannot step above the largest int
        hi++;
        pRight = pmfUntruncated(hi);
      }
      mass += pLeft + pRight;
      m1 += pLeft * lo + pRight * hi;
      m2 += pLeft * (double) lo * (double) lo + pRight * (double) hi * (double) hi;
    }

    // Re-normalize by the mass actually captured so the moments describe a proper distribution.
    if (mass > 0.0 && Math.abs(1.0 - mass) > 1e-15) {
      m1 /= mass;
      m2 /= mass;
    }
    double var = Math.max(0.0, m2 - m1 * m1); // numeric guard against a tiny negative
    return new double[] {m1, var};
  }

  /**
   * Computes mean and variance for the <em>truncated</em> rounded-normal by summing the
   * re-normalized PMF over {@code k = lower..upper}.
   *
   * <p>Only the part of the interval within {@code ±12σ} of {@code mean} is visited: outside it
   * the {@code erf} approximation has saturated, the PMF is exactly zero, and the skipped terms
   * would add nothing. This keeps wide windows cheap, and together with the {@code long} loop
   * counter it guarantees termination even when {@code upper == Integer.MAX_VALUE}.
   *
   * @param lower inclusive lower bound
   * @param upper inclusive upper bound
   * @param z renormalization constant {@code Z = P(lower ≤ Y ≤ upper)}; must be {@code > 0}
   * @return array {@code [mean, variance]}
   */
  private double[] computeTruncatedMoments(int lower, int upper, double z) {
    double reach = TRUNCATED_REACH_SIGMAS * sigma;
    // Math.max/min keep the values inside [lower, upper] before the (saturating) int cast.
    int from = (int) Math.max((double) lower, Math.floor(meanParam - reach));
    int to = (int) Math.min((double) upper, Math.ceil(meanParam + reach));

    double m1 = 0.0;
    double m2 = 0.0;
    for (long k = from; k <= to; k++) {
      double p = pmfUntruncated((int) k) / z;
      m1 += p * k;
      m2 += p * (double) k * (double) k;
    }
    double var = Math.max(0.0, m2 - m1 * m1); // numeric guard against a tiny negative
    return new double[] {m1, var};
  }

  /**
   * Draws a standard {@code Normal(0,1)} using the Marsaglia polar method, with a cached "spare"
   * sample so every other call is essentially free.
   *
   * @return a standard normal sample
   */
  private double nextGaussian() {
    if (haveSpare) {
      haveSpare = false;
      return spare;
    }
    double u;
    double v;
    double s;
    // Draw points in the square [-1, 1) x [-1, 1) until one lands strictly inside the unit
    // disc and away from the origin (s == 0 would make log(s) / s blow up).
    do {
      u = 2.0 * rng.nextDouble() - 1.0;
      v = 2.0 * rng.nextDouble() - 1.0;
      s = u * u + v * v;
    } while (s >= 1.0 || s == 0.0);
    double mul = Math.sqrt(-2.0 * Math.log(s) / s);
    spare = v * mul;
    haveSpare = true;
    return u * mul;
  }
}