Source code for susi.SOMPlots

"""SOMPlots functions."""

from typing import List, Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np


[docs] def plot_estimation_map( estimation_map: np.ndarray, cbar_label: str = "Variable in unit", cmap: str = "viridis", fontsize: int = 20, ) -> plt.Axes: """Plot estimation map. Parameters ---------- estimation_map : np.ndarray Estimation map of the size (n_rows, n_columns) cbar_label : str, optional Label of the colorbar, by default "Variable in unit" cmap : str, optional (default="viridis") Colormap fontsize : int, optional (default=20) Fontsize of the labels Returns ------- ax : pyplot.axis Plot axis """ _, ax = plt.subplots(1, 1, figsize=(7, 5)) img = ax.imshow(estimation_map, cmap=cmap) ax.set_xlabel("SOM columns", fontsize=fontsize) ax.set_ylabel("SOM rows", fontsize=fontsize) # ax.set_xticklabels(fontsize=fontsize) # ax.set_yticklabels(fontsize=fontsize) ax.tick_params(axis="both", which="major", labelsize=fontsize) # colorbar cbar = plt.colorbar(img, ax=ax) cbar.ax.tick_params(labelsize=fontsize) cbar.ax.set_ylabel(cbar_label, fontsize=fontsize, labelpad=10) for label in cbar.ax.xaxis.get_ticklabels()[::2]: label.set_visible(False) plt.grid(visible=False) return ax
[docs] def plot_som_histogram( bmu_list: List[Tuple[int, int]], n_rows: int, n_columns: int, n_datapoints_cbar: int = 5, fontsize: int = 22, ) -> plt.Axes: """Plot 2D Histogram of SOM. Plot 2D Histogram with one bin for each SOM node. The content of one bin is the number of datapoints matched to the specific node. Parameters ---------- bmu_list : list of (int, int) tuples Position of best matching units (row, column) for each datapoint n_rows : int Number of rows for the SOM grid n_columns : int Number of columns for the SOM grid n_datapoints_cbar : int, optional (default=5) Maximum number of datapoints shown on the colorbar fontsize : int, optional (default=22) Fontsize of the labels Returns ------- ax : pyplot.axis Plot axis """ fig, ax = plt.subplots(1, 1, figsize=(6, 6)) # colormap cmap = plt.cm.viridis cmaplist = [cmap(i) for i in range(cmap.N)] cmap = matplotlib.colors.LinearSegmentedColormap.from_list( "mcm", cmaplist, cmap.N ) bounds = np.arange(0.0, n_datapoints_cbar + 1, 1.0) norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N) ax2 = fig.add_axes([0.96, 0.12, 0.03, 0.76]) cbar = matplotlib.colorbar.ColorbarBase( ax2, cmap=cmap, norm=norm, spacing="proportional", ticks=bounds, boundaries=bounds, format="%1i", extend="max", ) cbar.ax.set_ylabel("Number of datapoints", fontsize=fontsize) cbar.ax.tick_params(labelsize=fontsize) ax.hist2d( [x[0] for x in bmu_list], [x[1] for x in bmu_list], bins=[n_rows, n_columns], cmin=1, cmap=cmap, norm=norm, ) for label in cbar.ax.xaxis.get_ticklabels()[::2]: label.set_visible(False) ax.set_xlabel("SOM columns", fontsize=fontsize) ax.set_ylabel("SOM rows", fontsize=fontsize) for tick in ax.xaxis.get_major_ticks(): tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks(): tick.label.set_fontsize(fontsize) # to be compatible with plt.imshow: ax.invert_yaxis() plt.grid(visible=False) return ax
[docs] def plot_umatrix( u_matrix: np.ndarray, n_rows: int, n_colums: int, cmap: str = "Greys", fontsize: int = 18, ) -> plt.Axes: """Plot u-matrix. Parameters ---------- u_matrix : np.ndarray U-matrix containing the distances between all nodes of the unsupervised SOM. Shape = (n_rows*2-1, n_columns*2-1) n_rows : int Number of rows for the SOM grid n_columns : int Number of columns for the SOM grid cmap : str, optional (default="Greys) Colormap fontsize : int, optional (default=18) Fontsize of the labels Returns ------- ax : pyplot.axis Plot axis """ _, ax = plt.subplots(figsize=(6, 6)) img = ax.imshow(u_matrix.squeeze(), cmap=cmap) ax.set_xticks(np.arange(0, n_colums * 2 + 1, 20)) ax.set_xticklabels(np.arange(0, n_colums + 1, 10)) ax.set_yticks(np.arange(0, n_rows * 2 + 1, 20)) ax.set_yticklabels(np.arange(0, n_rows + 1, 10)) # ticks and labels for tick in ax.xaxis.get_major_ticks(): tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks(): tick.label.set_fontsize(fontsize) ax.set_ylabel("SOM rows", fontsize=fontsize) ax.set_xlabel("SOM columns", fontsize=fontsize) # colorbar cbar = plt.colorbar(img, ax=ax, fraction=0.04, pad=0.04) cbar.ax.set_ylabel( "Distance measure (a.u.)", rotation=90, fontsize=fontsize, labelpad=20 ) cbar.ax.tick_params(labelsize=fontsize) return ax
[docs] def plot_nbh_dist_weight_matrix(som, it_frac: float = 0.1) -> plt.Axes: """Plot neighborhood distance weight matrix in 3D. Parameters ---------- som : susi.SOMClustering or related Trained (un)supervised SOM it_frac : float, optional (default=0.1) Fraction of `som.n_iter_unsupervised` for the plot state. Returns ------- ax : pyplot.axis Plot axis """ nbh_func = som._calc_neighborhood_func( curr_it=som.n_iter_unsupervised * it_frac, mode=som.neighborhood_mode_unsupervised, ) dist_weight_matrix = som._get_nbh_distance_weight_matrix( neighborhood_func=nbh_func, bmu_pos=[som.n_rows // 2, som.n_columns // 2], ) fig = plt.figure(figsize=(10, 10)) ax = fig.add_subplot(111, projection="3d") x = np.arange(som.n_rows) y = np.arange(som.n_columns) X, Y = np.meshgrid(x, y) Z = dist_weight_matrix.reshape(som.n_rows, som.n_columns) surf = ax.plot_surface( X, Y, Z, cmap=matplotlib.cm.coolwarm, antialiased=False, rstride=1, cstride=1, linewidth=0, ) fig.colorbar(surf, shrink=0.5, aspect=10) return ax