Example
Here we demonstrate the usage of RedClust through an example. This example can be downloaded as a Julia script here.
We begin by setting up the necessary includes.
using RedClust, Plots, StatsPlots
using Random: seed!
using StatsBase: counts
using LinearAlgebra: triu, diagind
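Since both the data generation and the sampler are stochastic, it can help to fix the random seed so that repeated runs of the script give the same output. The seed value below is arbitrary and purely illustrative.
seed!(44) # fix the RNG so the example is reproducible; any integer seed works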
Next we define some convenience functions for plotting.
# Heatmap of square matrix
function sqmatrixplot(X::Matrix; kwargs...)
M, N = size(X)
heatmap(
X,
aspect_ratio=:equal,
color=:Blues,
xlim=(1,N), ylim=(1,M), # columns on the x-axis, rows on the y-axis
yflip = true, xmirror=true;
kwargs...)
end
# Histogram with integer bins
function histogram_pmf(X::AbstractVector{<:Integer}; binwidth::Int=1, kwargs...)
xmin = minimum(X)
xmax = maximum(X)
binedges = xmin:binwidth:xmax
c = counts(X)./length(X) # relative frequency of each integer value between xmin and xmax
# sum the per-value frequencies within each bin of width `binwidth`
bincounts = map(sum, [c[(i-xmin+1):minimum([i-xmin+1+binwidth-1, xmax-xmin+1])] for i in binedges])
bar(binedges, bincounts,
linewidth = 0,
legend = false,
xticks = binedges; kwargs...)
end
# Combine two symmetric square matrices into the lower and upper triangles of a single square matrix
function combine_sqmatrices(lower::Matrix, upper::Matrix, diagonal::String = "lower")
if size(lower)[1] != size(lower)[2]
throw(ArgumentError("Argument `lower` must be square, has dimensions $(size(lower))."))
end
if size(upper)[1] != size(upper)[2]
throw(ArgumentError("Argument `upper` must be a square matrix, has dimensions $(size(upper))."))
end
if !all(size(lower) .== size(upper))
throw(ArgumentError("Arguments `lower` and `upper` must have the same size."))
end
if !(eltype(lower) <: eltype(upper)) && !(eltype(upper) <: eltype(lower))
throw(ArgumentError("Arguments must have compatible entries, got $(eltype(lower)) and $(eltype(upper))."))
end
if diagonal ∉ ["lower", "upper"]
throw(ArgumentError("Keyword argument `diagonal` must be either \"lower\" or \"upper\"."))
end
result = copy(lower)
temp = trues(size(lower))
upper_idx = triu(temp, 1)
diagonal_idx = diagind(temp)
result[upper_idx] .= upper[upper_idx]
result[diagonal_idx] .= ((diagonal == "lower") ? lower : upper)[diagonal_idx]
return result
end
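As a quick illustration of what combine_sqmatrices does (this toy check is not part of the original script), the lower triangle and diagonal of the first argument are kept while the strict upper triangle is taken from the second argument:
A = [1.0 2.0; 2.0 4.0]
B = [10.0 20.0; 20.0 40.0]
combine_sqmatrices(A, B) # returns [1.0 20.0; 2.0 4.0]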
Generating Data
We can generate some example data using the function generatemixture.
begin
K = 10 # Number of clusters
N = 100 # Number of points
data_σ = 0.25 # Standard deviation of the normal kernel
data_dim = 10 # Data dimension
α = 10 # parameter for Dirichlet prior on cluster weights
data = generatemixture(N, K;
α = α, σ = data_σ, dim = data_dim)
points, distmatrix, clusts, probs, oracle_coclustering = data
end
Alternatively, the function example_dataset
can be used to retrieve the datasets used in the original RedClust paper.
begin
data = example_dataset(1)
points, distmatrix, clusts, probs, oracle_coclustering = data
end
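Before visualising anything it can be useful to glance at the ground-truth cluster sizes. This quick check is not part of the original script; it uses counts from StatsBase and assumes clusts is a vector of integer cluster labels, which is how it is used throughout this example.
length(unique(clusts)) # true number of clusters
counts(clusts) # number of observations in each true cluster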
We can visualise the true adjacency matrix of the observations with respect to the true clusters from which they were drawn, as well as the oracle coclustering matrix. The latter is the matrix of co-clustering probabilities of the observations, conditioned on the data-generation process; it takes into account full information about the cluster weights (and how they are generated), the mixture kernel for each cluster, and the location and scale parameters of those kernels.
sqmatrixplot(combine_sqmatrices(oracle_coclustering, 1.0 * adjacencymatrix(clusts)), title = "Adjacency vs Oracle Co-clustering Probabilities \n(upper right and lower left triangle)\n")
We can visualise the matrix of pairwise distances between the observations.
sqmatrixplot(distmatrix, title = "Matrix of Pairwise Distances")
We can also plot the histogram of distances, grouped by whether they are inter-cluster distances (ICD) or within-cluster distances (WCD).
begin
empirical_intracluster = uppertriangle(distmatrix)[
uppertriangle(adjacencymatrix(clusts)) .== 1]
empirical_intercluster = uppertriangle(distmatrix)[
uppertriangle(adjacencymatrix(clusts)) .== 0]
histogram(empirical_intercluster,
bins = minimum(empirical_intercluster):0.05:maximum(empirical_intercluster),
label="ICD", xlabel = "Distance", ylabel="Frequency",
title = "Observed distribution of distances")
histogram!(empirical_intracluster,
bins = minimum(empirical_intracluster):0.05:maximum(empirical_intracluster),
label="WCD")
end
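A simple numerical summary can complement the histogram above; the sketch below just compares the average within-cluster and inter-cluster distances, using mean from the standard library Statistics (an extra import not in the original includes).
using Statistics: mean
mean(empirical_intracluster), mean(empirical_intercluster) # average WCD vs average ICD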
Prior Hyperparameters
RedClust includes the function fitprior
to heuristically choose prior hyperparameters based on the data.
params = fitprior(points, "k-means", false)
Model Hyperparameters
Likelihood Hyperparameters
δ₁ = 17.380
δ₂ = 27.626
α = 7421.121
β = 469.274
ζ = 124952.409
γ = 8099.269
Partition Prior Hyperparameters
η = 5.746
σ = 1.810
u = 16.546
v = 6.890
Miscellaneous Hyperparameters
Proposal standard deviation for sampling r = 1.380
Repulsion is enabled? true
Maximum number of clusters = none
Initial number of clusters = 12
We can check how good the chosen prior hyperparameters are by comparing the empirical distribution of distances to the (marginal) prior predictive distribution.
begin
pred_intracluster = sampledist(params, "intracluster", 10000)
pred_intercluster = sampledist(params, "intercluster", 10000)
density(pred_intercluster,
label="Simulated ICD", xlabel = "Distance", ylabel = "Density",
linewidth = 2, linestyle = :dash)
density!(empirical_intercluster,
label="Empirical ICD",
color = 1, linewidth = 2)
density!(pred_intracluster,
label="Simulated WCD",
linewidth = 2, linestyle = :dash, color = 2)
density!(empirical_intracluster,
label="Empirical WCD",
linewidth = 2, color = 2)
title!("Distances: Prior Predictive vs Empirical Distribution")
end
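If a numerical check is preferred over a visual one, the simulated and empirical distances can also be compared through a few quantiles. This is only a rough sketch; it uses quantile from the standard library Statistics, which is not among the original includes.
using Statistics: quantile
qs = [0.25, 0.5, 0.75]
quantile(pred_intracluster, qs), quantile(empirical_intracluster, qs) # simulated vs empirical WCD quartiles
quantile(pred_intercluster, qs), quantile(empirical_intercluster, qs) # simulated vs empirical ICD quartiles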
We can also evaluate the prior hyperparameters by checking the marginal prior predictive distribution of $K$ (the number of clusters).
begin
Ksamples = sampleK(params, 10000, N)
histogram_pmf(Ksamples, legend = false,
xticks=collect(0:10:maximum(Ksamples)),
xlabel = "\$K\$", ylabel = "Probability", title = "Marginal Prior Predictive Distribution of \$K\$")
end
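The same prior predictive samples can be summarised numerically as well, for instance via the mean of $K$ and the probability assigned to the true number of clusters. This is an illustrative check rather than part of the original script, and again borrows mean from Statistics.
using Statistics: mean
mean(Ksamples) # prior predictive mean of K
mean(Ksamples .== length(unique(clusts))) # prior predictive probability of the true K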
Sampling
Running the MCMC is straightforward. We set up the MCMC options using MCMCOptionsList.
options = MCMCOptionsList(numiters=50000)
MCMC Options
50000 iterations
10000 burnin iterations
40000 samples
5 restricted Gibbs steps per split-merge step
1 split-merge step per iteration
We then set up the input data using MCMCData.
data = MCMCData(points)
MCMC data : 100×100 dissimilarity matrix.
We can then run the sampler using runsampler.
result = runsampler(data, options, params)
MCMC Summary
General
50000 iterations
10000 iterations discarded as burnin
40000 samples
5 restricted Gibbs steps per split-merge step
1 split-merge step per iteration
Acceptance rate for split-merge steps = 1.040e-03
Acceptance rate for sampling r = 3.808e-01
Using repulsion: true
Maximum number of clusters allowed: any
Runtime = 1 min 32 s
Time per iteration : 1.84 ms
Summary for K
IAC : 6.240
ESS : 6410.015
ESS per sample : 1.603e-01
Posterior mean : 13.692
Posterior variance : 6.400e-01
Summary for r
IAC : 17.886
ESS : 2236.442
ESS per sample : 5.591e-02
Posterior mean : 2.437
Posterior variance : 7.101e-01
Summary for p
IAC : 13.191
ESS : 3032.333
ESS per sample : 7.581e-02
Posterior mean : 7.243e-01
Posterior variance : 4.006e-03
MCMC Result
The MCMC result contains several details about the run, including acceptance rates, runtime, and convergence diagnostics; for full details see MCMCResult. In this example we have the ground truth cluster labels, so we can evaluate the result. For example, we can compare the posterior coclustering matrix to the oracle co-clustering probabilities.
sqmatrixplot(combine_sqmatrices(result.posterior_coclustering, oracle_coclustering),
title="Posterior vs Oracle Coclustering Probabilities")
Plot the posterior distribution of $K$:
histogram_pmf(result.K,
xlabel = "\$K\$", ylabel = "Probability", title = "Posterior Distribution of \$K\$")
Plot the posterior distribution of $r$:
begin
histogram(result.r, normalize = :pdf,
legend_font_pointsize=12,
label="Empirical density", ylabel = "Density", xlabel = "\$r\$",
title = "Posterior Distribution of \$r\$")
density!(result.r,
color=:black, linewidth = 2, linestyle=:dash,
label="Kernel estimate", legend_font_pointsize=12)
end
Plot the posterior distribution of $p$:
begin
histogram(result.p, normalize = :pdf,
ylabel = "Density", xlabel = "\$p\$",
title = "Posterior Distribution of \$p\$",
label = "Empirical density",
legend_font_pointsize=12)
density!(result.p, color=:black, linewidth = 2, linestyle=:dash,
label = "Kernel estimate",
legend_position = :topleft)
end
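Equal-tailed credible intervals for $r$ and $p$ can be read off directly from the posterior samples. The sketch below uses quantile from Statistics and is an addition for illustration, not part of the original script.
using Statistics: quantile
quantile(result.r, [0.025, 0.975]) # 95% credible interval for r
quantile(result.p, [0.025, 0.975]) # 95% credible interval for p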
Convergence statistics
Plot the autocorrelation function of $K$:
plot(result.K_acf, legend = false, linewidth = 1,
xlabel = "Lag", ylabel = "Autocorrelation",
title = "Autocorrelation Function of \$K\$")
Plot the autocorrelation function of $r$:
plot(result.r_acf, legend = false, linewidth = 1,
xlabel = "Lag", ylabel = "Autocorrelation",
title = "Autocorrelation Function of \$r\$")
Plot the autocorrelation function of $p$:
plot(result.p_acf, legend = false, linewidth = 1,
xlabel = "Lag", ylabel = "Autocorrelation",
title = "Autocorrelation Function of \$p\$")
Check the trace plot of the log-likelihood to make sure the MCMC is moving well:
plot(result.loglik, legend = false, linewidth = 1,
xlabel = "Iteration", ylabel = "Log likelihood",
title = "Log-Likelihood Trace Plot")
Check the trace plot of the log-posterior:
plot(result.logposterior, legend = false, linewidth = 1,
xlabel = "Iteration", ylabel = "Log posterior",
title = "Log-Posterior Trace Plot")
Point Estimates
The function getpointestimate
finds an optimal point estimate, based on some notion of optimality. For example, to get the maximum a posteriori estimate we can run the following.
pointestimate, index = getpointestimate(result; method="MAP")
([1, 1, 1, 1, 1, 2, 1, 1, 1, 3 … 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 9716)
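Before computing clustering metrics it can help to inspect the point estimate itself; since it is just a vector of cluster labels, the same counts helper from StatsBase applies. This snippet is illustrative and not part of the original script.
length(unique(pointestimate)) # number of clusters in the point estimate
counts(pointestimate) # size of each estimated cluster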
We can compare the point estimate to the true clustering through their adjacency matrices.
sqmatrixplot(combine_sqmatrices(adjacencymatrix(pointestimate), adjacencymatrix(clusts)),
title = "True Clustering vs MAP Point Estimate")
We can check the accuracy of the point estimate in terms of clustering metrics.
summarise(pointestimate, clusts)
Clustering summary
Number of clusters : 13
Normalised Binder loss : 0.01616161616161616
Adjusted Rand Index : 0.9028187471471413
Normalised Variation of Information (NVI) distance : 0.06086895626900138
Normalised Information Distance (NID) : 0.0429266154915723
Normalised Mutual Information : 0.9400474997102378
This page was generated using Literate.jl.