local statistics = {}

-- sqrt(2): rescales erf_inv(2u - 1) into a standard normal quantile.
local ROOT_2 = math.sqrt(2.0)

-- Closed-form approximations of erf(x) and its inverse, taken from
-- https://en.wikipedia.org/wiki/Error_function
-- Both are forward-declared here and assigned below.
local erf, erf_inv

-- A is the approximation's shape constant a = 8(pi - 3) / (3 pi (4 - pi)).
-- B = 4 / pi is used by erf; C = 2 / (pi a) and D = 1 / a are used by erf_inv.
local A, B, C, D
do
	local pi = math.pi

	A = 8 * (pi - 3.0)/(3.0 * pi * (4.0 - pi))
	B = 4.0 / pi
	C = 2.0/(pi * A)
	D = 1.0 / A
end

--- Approximate the error function erf(x).
--
-- @param x  A number.
-- @return   The approximation of erf(x): exactly 0 for x == 0, and +/-1
--           once x * x overflows to infinity.
erf = function(x)

	if x == 0 then return 0; end

	-- Force floating point. On Lua 5.3+ an integer x would make x * x wrap
	-- around once |x| >= 2^31.5 (erf(2^32) would otherwise return 0).
	x = x + 0.0

	local xSq = x * x

	-- Beyond |x| ~ 1.34e154, x * x overflows. The formula below would then
	-- evaluate inf / inf = NaN, whereas its limit is +/-1.
	if xSq == math.huge then return (x > 0 and 1.0) or -1.0; end

	local aXSq = A * xSq
	local v = math.sqrt(1.0 - math.exp(-xSq * (B + aXSq) / (1.0 + aXSq)))

	return (x > 0 and v) or -v
end

--- Approximate the inverse error function erf^-1(x).
--
-- @param x  A number.
-- @return   0 for x == 0; nil when x <= -1 or x >= 1; otherwise the
--           approximation, with the same sign as x.
erf_inv = function(x)

	-- An exact zero needs no work.
	if x == 0 then return 0; end

	-- The inverse only exists on the open interval (-1, 1).
	if x <= -1 or x >= 1 then return nil; end

	local log_term = math.log(1 - x * x)
	local half = C + 0.5 * log_term
	local magnitude = math.sqrt(math.sqrt(half * half - D * log_term) - half)

	if x > 0 then
		return magnitude
	end

	return -magnitude
end

--- Standard normal quantile (inverse CDF), computed as
-- sqrt(2) * erf_inv(2u - 1).
--
-- @param u  A probability.
-- @return   The approximate z with P(Z < z) = u. Returns -math.huge for
--           u <= 0 and math.huge for u >= 1, where erf_inv gives nil.
local function std_normal(u)

	local e = erf_inv(2.0 * u - 1.0)

	if e == nil then
		-- erf_inv rejects |2u - 1| >= 1. math.random() can return exactly 0,
		-- so return the limiting value rather than failing on `ROOT_2 * nil`.
		return (u > 0.5 and math.huge) or -math.huge
	end

	return ROOT_2 * e
end

local poisson

-- cdf_table[li] holds cumulative Poisson probabilities for lambda = 0.25 * li,
-- for li = 1 .. 100 (lambda = 0.25 .. 25.0).
local cdf_table = {}

--- Build a cumulative distribution table for a Poisson(lambda) variable.
--
-- @param lambda  The distribution mean.
-- @return        A table t with t[k] = P(X <= k) for
--                k = 0 .. ceil(4 * lambda) - 1. Entries from 1 upward form
--                a sequence, so #t = ceil(4 * lambda) - 1.
local function generate_cdf(lambda)

	local count = math.ceil(4 * lambda)
	local pdf = math.exp(-lambda)
	local cdf = pdf
	local t = { [0] = pdf }

	for k = 1, count - 1 do
		-- P(X = k) = P(X = k - 1) * lambda / k
		pdf = pdf * lambda / k
		cdf = cdf + pdf
		t[k] = cdf
	end

	return t
end

for li = 1, 100 do
	cdf_table[li] = generate_cdf(0.25 * li)
end

--- Draw a Poisson(lambda) sample capped at max (internal worker).
--
-- The method depends on lambda:
--   * lambda in [0.125, 25.125): inverse-CDF lookup in cdf_table, with
--     lambda rounded to the nearest 0.25.
--   * smaller lambda: a 0-or-1 draw.
--   * larger lambda: a rounded normal approximation with the same mean and
--     variance.
--
-- @param lambda  Mean of the distribution (the public wrapper passes > 0).
-- @param max     Upper bound for the result (the public wrapper passes >= 1).
-- @return        An integer between 0 and floor(max), both inclusive.
poisson = function(lambda, max)

	-- Only the integer part of the bound is meaningful. Without this, a
	-- fractional max could be returned as-is or exceeded by one.
	max = math.floor(max)

	if max < 2 then
		return (math.random() < math.exp(-lambda) and 0) or 1
	elseif lambda >= 2 * max then
		-- The mean is far above the cap: return the cap directly.
		return max
	end

	local u = math.random()

	-- cdf_table is indexed by lambda rounded to the nearest 0.25.
	local lambda_index = math.floor(4 * lambda + 0.5)
	local cdfs = cdf_table[lambda_index]

	if cdfs then

		if u < cdfs[0] then return 0; end

		-- cdfs[k] exists for k = 0 .. #cdfs. Any mass beyond that, or beyond
		-- max, is lumped into the cap.
		max = math.min(max, #cdfs + 1)

		if u >= cdfs[max - 1] then return max; end

		-- Here cdfs[0] <= u < cdfs[max - 1]. Find the smallest k with
		-- u < cdfs[k]; it lies in 1 .. max - 1.
		if max > 4 then  -- Binary search; invariant: cdfs[lo] <= u.
			local lo, hi = 0, max

			while lo + 1 < hi do
				local mid = math.floor(0.5 * (lo + hi))

				if u < cdfs[mid] then hi = mid; else lo = mid; end
			end

			return hi
		end

		for k = 1, max - 1 do
			if u < cdfs[k] then return k; end
		end

		return max
	elseif lambda_index < 1 then
		-- lambda < 0.125 has no table entry. A normal approximation is badly
		-- biased this close to 0, so use the same 0-or-1 draw as max < 2.
		return (u < math.exp(-lambda) and 0) or 1
	else
		-- u == 0 would map to erf_inv(-1), which is nil, so draw again.
		while u == 0 do u = math.random() end

		local x = lambda + math.sqrt(lambda) * std_normal(u)

		return (x < 0.5 and 0) or (x >= max - 0.5 and max) or math.floor(x + 0.5)
	end
end

-- Public approximations: the error function, and its inverse (the inverse
-- returns nil outside the open interval (-1, 1)).
statistics.erf = erf
statistics.erf_inv = erf_inv

--- Draw a random sample from the standard normal distribution (mean 0,
-- standard deviation 1), censored to the interval [-3, 3].
--
-- Samples come from inverse transform sampling through the erf_inv
-- approximation. Results beyond three standard deviations are reported as
-- exactly -3.0 or 3.0.
--
-- @return
--    A real number between -3.0 and 3.0 (both inclusive).

statistics.std_normal = function()

	local u = math.random()

	-- Keep u well inside (0, 1): erf_inv returns nil at the endpoints.
	if u < 0.001 then
		return -3.0
	elseif u > 0.999 then
		return 3.0
	end

	-- std_normal(0.001) is roughly -3.08. Without this clamp, values just
	-- outside [-3, 3] would still appear next to the -3.0 / 3.0 caps above.
	return math.max(-3.0, math.min(3.0, std_normal(u)))
end

--- Draw a random sample from a normal distribution with mean mu and
-- standard deviation sigma, censored to mu +/- 3 * sigma.
--
-- @param mu
--    The distribution mean.
-- @param sigma
--    The distribution standard deviation.
-- @return
--    A real number between mu - 3*sigma and mu + 3*sigma (both inclusive).

statistics.normal = function(mu, sigma)

	local u = math.random()

	-- Keep u well inside (0, 1): erf_inv returns nil at the endpoints.
	if u < 0.001 then
		return mu - 3.0 * sigma
	elseif u > 0.999 then
		return mu + 3.0 * sigma
	end

	-- std_normal(0.001) is roughly -3.08. Clamp so that no sample lands
	-- beyond the +/- 3 sigma caps returned above.
	local z = math.max(-3.0, math.min(3.0, std_normal(u)))

	return mu + sigma * z
end

--- Draw a random sample from a Poisson distribution; values above max are
-- returned as max.
--
-- @param lambda
--    The distribution mean and variance (a number or numeric string).
-- @param max
--    The distribution maximum (a number or numeric string); only its
--    integer part is used.
-- @return
--    An integer between 0 and max (both inclusive). Returns 0 when either
--    argument is not numeric or is NaN, when lambda <= 0, or when max < 1.

statistics.poisson = function(lambda, max)

	lambda, max = tonumber(lambda), tonumber(max)

	-- The negated comparisons also reject NaN, for which every comparison
	-- is false.
	if not lambda or not max or not (lambda > 0) or not (max >= 1) then
		return 0
	end

	return poisson(lambda, math.floor(max))
end

-- Module table exposing erf, erf_inv, std_normal, normal and poisson.
return statistics