import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from itertools import combinations
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MaxNLocator
import sys
from .._configs import *
from .numerical_utils import gaussian_kde
# Optional LaTeX rendering of all text, controlled by the ``__useTeX__`` flag
# (presumably provided by the star import from ``.._configs``).
if __useTeX__:
    # Set one rcParams key at a time (same effect as a single rcParams.update).
    # A serif alternative would be "font.family": "serif" together with
    # "font.serif": ["Computer Modern Roman"].
    plt.rcParams["text.usetex"] = True
    plt.rcParams["mathtext.fontset"] = "cm"
    # NOTE(review): a serif font name listed under "font.sans-serif" looks
    # unintended -- confirm which typeface is wanted.
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.sans-serif"] = ["Computer Modern Serif"]

# Public plotting API exported by ``from ... import *``.
__all__ = [
    "plot_states",
    "plot_regularization_path",
    "plot_classifier_complexity_vs_accuracy",
    "plot_combination_states_features",
    "plot_states_features",
    "plot_histogram_features",
]
##########################################################################
## FESSA COLOR PALETTE
# https://github.com/luigibonati/fessa-color-palette/blob/master/fessa.py
##########################################################################
from matplotlib.colors import LinearSegmentedColormap, ColorConverter

try:
    # ``matplotlib.cm.register_cmap`` was deprecated in matplotlib 3.6 and
    # removed in 3.9; only needed as a fallback on old versions.
    from matplotlib.cm import register_cmap
except ImportError:
    register_cmap = None

paletteFessa = [
    '#1F3B73', # dark-blue
    '#2F9294', # green-blue
    '#50B28D', # green
    '#A7D655', # pisello
    '#FFE03E', # yellow
    '#FFA955', # orange
    '#D6573B', # red
]

cm_fessa = LinearSegmentedColormap.from_list('fessa', paletteFessa)


def _register_colormap(cmap):
    """Register ``cmap`` under ``cmap.name`` using the API this matplotlib provides.

    The ``matplotlib.colormaps`` registry raises if a name is already taken
    (e.g. when this module is re-imported), so existing names are skipped.
    """
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "register"):
        if cmap.name not in registry:
            registry.register(cmap)
    else:
        register_cmap(cmap=cmap)


# Registers 'fessa' and its reversed version 'fessa_r'.
_register_colormap(cm_fessa)
_register_colormap(cm_fessa.reversed())

# Named colors 'fessa0' ... 'fessa6', usable as e.g. ``color='fessa0'``.
for i, _fessa_color in enumerate(paletteFessa):
    ColorConverter.colors[f'fessa{i}'] = _fessa_color
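### To set it as default (importing this module already registers 'fessa' and
### 'fessa_r'; the upstream standalone palette is loaded with `import fessa`)
# plt.set_cmap('fessa')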
### or the reversed one
# plt.set_cmap('fessa_r')
### For contour plots
# plt.contourf(X, Y, Z, cmap='fessa')
### For standard plots
# plt.plot(x, y, color='fessa0')
##########################################################################
# aux function to compute basins mean
def compute_basin_mean(df, basin, label_x, label_y):
    """Return ``(mean of label_x, mean of label_y)`` over the rows of ``df``
    whose ``'basin'`` column equals ``basin``.

    Both means are NaN when no row belongs to ``basin``.
    """
    in_basin = df[df['basin'] == basin]
    return in_basin[label_x].mean(), in_basin[label_y].mean()
def plot_regularization_path(classifier, reg):
    """Plot every state's coefficient path and the cross-validated accuracy.

    Parameters
    ----------
    classifier : classifier object
        Must already be computed (``classifier._computed``). Reads ``_coeffs``,
        ``_reg``, ``_crossval``, ``_classes_labels``, ``classes`` and
        ``_closest_reg_idx``.
    reg : float
        Regularization strength to highlight; the closest value in
        ``classifier._reg`` is marked with a vertical line in every panel.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : list of Axes
        One panel per basin, followed by the accuracy panel (``axes[-1]``).
    """
    assert classifier._computed, "You have to run Classifier.compute first."
    reg_idx = classifier._closest_reg_idx(reg)
    n_basins = classifier._coeffs.shape[1]
    n_cols = 3
    # One panel per basin plus one for the accuracy. ``np.int`` was removed
    # in NumPy 1.24, so convert with the builtin ``int``.
    rows = int(np.ceil((n_basins + 1) / n_cols))
    fig = plt.figure(constrained_layout=True, figsize=(8, 2.5 * rows))
    gs = GridSpec(rows, n_cols, figure=fig)
    axes = [
        fig.add_subplot(gs[np.unravel_index(panel, (rows, n_cols))])
        for panel in range(n_basins + 1)
    ]
    fig.suptitle(r"Regularization paths")

    log_reg = np.log10(classifier._reg)
    log_reg_selected = np.log10(classifier._reg[reg_idx])

    for idx, state_idx in enumerate(classifier._classes_labels[reg_idx]):
        ax = axes[idx]
        _cfs = classifier._coeffs[:, idx, :]
        # Keep features whose coefficients, summed along the path, are at least
        # __EPS__ in magnitude. NOTE(review): a signed sum can cancel for a
        # coefficient that changes sign along the path -- confirm intent.
        killer = np.abs(np.sum(_cfs, axis=0)) >= __EPS__
        ax.plot(log_reg, _cfs[:, killer], 'k-')
        ax.axvline(x=log_reg_selected, color='tomato', linewidth=0.75)
        ax.set_xmargin(0)
        ax.set_xlabel(r"$\log_{10}(\lambda)$")
        ax.set_title(classifier.classes[state_idx])

    ax = axes[-1]
    ax.plot(log_reg, classifier._crossval, 'k-')
    # Same marker color as the per-state panels (was 'r').
    ax.axvline(x=log_reg_selected, color='tomato', linewidth=0.75)
    ax.set_xmargin(0)
    ax.set_ylim(0, 1.1)
    ax.set_xlabel(r"$\log_{10}(\lambda)$")
    ax.set_title(r"Accuracy")
    return fig, axes
def plot_classifier_complexity_vs_accuracy(classifier, feature_mode = False, ax = None):
    """Plot accuracy and model size as functions of the regularization strength.

    The left axis (``ax1``) shows ``classifier._crossval``. The twin right axis
    (``ax2``) shows, for every value in ``classifier._reg``, how many distinct
    first entries appear across all states in
    ``classifier._get_selected(reg, feature_mode=feature_mode)``.

    Parameters
    ----------
    classifier : classifier object
        Must already be computed (``classifier._computed``).
    feature_mode : bool
        Forwarded to ``classifier._get_selected``.
    ax : Axes, optional
        Axes for the accuracy curve; a new figure is created when None.

    Returns
    -------
    ``(ax1, ax2)`` when ``ax`` is given, otherwise ``(fig, (ax1, ax2))``.
    """
    assert classifier._computed, "You have to run Classifier.compute first."

    def _n_distinct_selected(strength):
        # Union over states of the first element of each selected entry.
        picked = classifier._get_selected(strength, feature_mode=feature_mode)
        return len({entry[0] for entries in picked.values() for entry in entries})

    num_groups = [_n_distinct_selected(strength) for strength in classifier._reg]

    caller_axes = ax is not None
    if caller_axes:
        ax1 = ax
    else:
        fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax2.grid(alpha=0.3)

    log_reg = np.log10(classifier._reg)
    ax2.plot(log_reg, num_groups, '--', color='fessa1')
    ax1.plot(log_reg, classifier._crossval, '-', color='fessa0')
    ax2.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax1.set_xlabel(r"$\log_{10}(\lambda)$")
    ax1.set_ylabel('Accuracy', color='fessa0')
    ax1.set_ylim(0, 1.1)
    # NOTE(review): the label depends only on whether groups exist, not on
    # ``feature_mode`` -- confirm it is right when feature_mode=True.
    desc = "Features" if classifier._groups is None else "Groups"
    ax2.set_ylabel(f'Number of {desc}', color='fessa1')
    ax1.set_xmargin(0)

    if caller_axes:
        return (ax1, ax2)
    return fig, (ax1, ax2)
def plot_states(colvar, state_labels, selected_cvs, fes_isolines = False, n_iso_fes = 9, ev_iso_labels = 2, alpha=0.3, cmap_name = 'Set2', save_folder=None, axs = None, **kde_kwargs):
    """Hexbin-plot the metastable states on every pair of selected CVs.

    Parameters
    ----------
    colvar : pandas.DataFrame
        Collective variables; the columns named in ``selected_cvs`` are plotted.
    state_labels : pandas.DataFrame
        ``'labels'``: state of every sample, used as the hexbin ``C`` value
        (so presumably numeric). ``'selection'``: boolean mask; unselected
        samples are drawn with transparency ``alpha``.
    selected_cvs : sequence of str
        One panel is drawn per unordered pair of these columns.
    fes_isolines : bool
        If True, overlay ``n_iso_fes`` isolines of ``-log p`` (labelled in
        units of kBT), estimated with ``gaussian_kde(..., **kde_kwargs)``;
        every ``ev_iso_labels``-th isoline is labelled.
    cmap_name : str
        Colormap of the hexbins.
    save_folder : str or None
        If given, the figure is saved to ``save_folder + 'states.pdf'``
        (plain string concatenation: include the trailing separator).
    axs : Axes or array of Axes, optional
        One Axes per CV pair. If None, a new figure is created (and titled).

    Returns
    -------
    fig, axs
    """
    states = state_labels['labels'].unique()
    idxs_pairs = list(combinations(range(len(selected_cvs)), 2))
    n_pairs = len(idxs_pairs)

    # ``axs`` is rebound below, so remember now whether we own the figure
    # (testing ``axs is None`` afterwards was always False).
    own_figure = axs is None
    if own_figure:
        fig, axs = plt.subplots(1, n_pairs, figsize=(4.8 * n_pairs, 4), dpi=100)
    else:
        fig = (axs[0] if n_pairs > 1 else axs).figure

    z = state_labels['labels']
    sel = state_labels['selection']
    not_sel = np.logical_not(sel)
    # One color normalization for both hexbin layers, so that a given state
    # label maps to the same color in the faded and in the opaque layer.
    vmin, vmax = z.min(), z.max()

    for k, (x_idx, y_idx) in enumerate(idxs_pairs):
        label_x = selected_cvs[x_idx]
        label_y = selected_cvs[y_idx]
        ax = axs[k] if n_pairs > 1 else axs

        if fes_isolines:
            num_samples = 100
            # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9.
            grey = plt.get_cmap('Greys_r', n_iso_fes)
            iso_colors = [grey((i + 1) / (n_iso_fes + 3)) for i in range(n_iso_fes)]
            KDE = gaussian_kde(colvar[[label_x, label_y]].to_numpy(), **kde_kwargs)
            bounds = [(v.min(), v.max()) for v in KDE.dataset.T]
            mesh = np.meshgrid(*[np.linspace(lo, hi, num_samples) for lo, hi in bounds])
            positions = np.vstack([g.ravel() for g in mesh]).T
            fes = -KDE.logpdf(positions).reshape(num_samples, num_samples)
            fes -= fes.min()
            CS = ax.contour(*mesh, fes, levels=np.linspace(0, n_iso_fes - 1, n_iso_fes), colors=iso_colors)
            ax.clabel(CS, CS.levels[::ev_iso_labels], fmt=lambda v: str(int(v)) + r'$k_{{\rm B}}T$', inline=True, fontsize=8)

        x = colvar[label_x]
        y = colvar[label_y]
        ax.hexbin(x[not_sel], y[not_sel], C=z[not_sel], cmap=cmap_name, alpha=alpha, vmin=vmin, vmax=vmax)
        ax.hexbin(x[sel], y[sel], C=z[sel], cmap=cmap_name, vmin=vmin, vmax=vmax)
        if own_figure:
            ax.set_title('Metastable states identification')
        ax.set_xlabel(label_x)
        ax.set_ylabel(label_y)

        # Mark every state at the unweighted mean position of its selected samples.
        for b in states:
            mask = np.logical_and(sel, z == b)
            mx, my = np.mean(x[mask]), np.mean(y[mask])
            ax.scatter(mx, my, color='w', s=300, alpha=0.7)
            ax.text(mx, my, b, ha="center", va="center", color='k', fontsize='large')

    # Lay out before saving so the saved file matches what is displayed.
    if own_figure:
        fig.tight_layout()
    if save_folder is not None:
        fig.savefig(save_folder + 'states.pdf', bbox_inches='tight')
    return fig, axs
def plot_combination_states_features(colvar, descriptors, selected_cvs, relevant_features, state_labels = None, save_folder=None, file_prefix='linear'):
    """For every pair of selected CVs, plot the CV plane colored by each state's relevant features.

    Quadratic-kernel features (names of the form ``'a || b'``) that are missing
    from ``descriptors`` are computed as the product of the two columns and
    **added to ``descriptors`` in place** (a warning is printed to stderr).

    Parameters
    ----------
    colvar : pandas.DataFrame
        Holds the columns named in ``selected_cvs``.
    descriptors : pandas.DataFrame
        Feature values, passed to ``plot_states_features``.
    selected_cvs : sequence of str
        At least two CV column names; one figure per unordered pair.
    relevant_features : dict
        Maps state -> sequence of tuples whose item ``[2]`` is a feature name.
    state_labels : pandas.DataFrame, optional
        Forwarded to ``plot_states_features``.
    save_folder : str or None
        If given, each figure is saved as
        ``save_folder + file_prefix + '-relevant_feats{k}.png'``, where ``k`` is
        the 1-based pair number, or empty when there is a single pair.
    file_prefix : str

    Raises
    ------
    NotImplementedError
        If fewer than two CVs are selected.
    """
    if len(selected_cvs) < 2:
        raise NotImplementedError('This plot is available only when selecting 2 or more CVs.')

    added_columns = False
    for _state in relevant_features.values():
        for _feat_tuple in _state:
            feature = _feat_tuple[2]
            if "||" in feature and feature not in descriptors.columns:
                added_columns = True
                # Split on the same token that was tested, tolerating whitespace.
                left, right = (name.strip() for name in feature.split('||'))
                descriptors[feature] = descriptors[left].values * descriptors[right].values
    if added_columns:
        print("Warning: detected quadratic kernel features, added quadratic features to the input dataframe", file=sys.stderr)

    cv_pairs = list(combinations(selected_cvs, 2))
    n_pairs = len(cv_pairs)
    for k, (label_x, label_y) in enumerate(cv_pairs):
        cv_x = colvar[label_x].values
        cv_y = colvar[label_y].values
        plot_states_features(cv_x, cv_y, descriptors, relevant_features, state_labels=state_labels, max_nfeat = 3)
        if save_folder is not None:
            # Number the files only when there are several pairs; a single pair
            # previously produced '...-relevant_featsNone.png'.
            suffix = k + 1 if n_pairs > 1 else ''
            # plot_states_features just created a new figure, so it is current.
            plt.savefig(save_folder + file_prefix + f'-relevant_feats{suffix}.png',
                        facecolor='w',
                        transparent=False,
                        bbox_inches='tight')
def plot_states_features(cv_x, cv_y, descriptors, relevant_feat, state_labels = None, max_nfeat = 3):
    """Hexbin-plot the (cv_x, cv_y) plane colored by each state's relevant features.

    The grid has one row per entry of ``relevant_feat`` and ``max_nfeat``
    columns; panel ``(i, j)`` is colored by the j-th feature of the i-th state.

    Parameters
    ----------
    cv_x, cv_y : numpy.ndarray
        Coordinates of every sample (presumably aligned with ``descriptors``).
    descriptors : pandas.DataFrame
        Feature values; columns are looked up by the names in ``relevant_feat``.
    relevant_feat : dict
        Maps state name -> sequence of tuples where item ``[1]`` is the
        importance (shown multiplied by 100 with a % sign) and item ``[2]`` the
        feature name.
    state_labels : pandas.DataFrame, optional
        With ``'labels'`` and boolean ``'selection'`` columns. If given, only
        selected samples are plotted and every state's label is drawn at the
        mean position of its samples.
    max_nfeat : int
        Maximum number of features plotted per state.

    Returns
    -------
    fig, axs
        ``axs`` is always 2-D with shape ``(len(relevant_feat), max_nfeat)``.
    """
    n_basins = len(relevant_feat)
    if state_labels is not None:
        selected = state_labels['selection']
        cv_x = cv_x[selected]
        cv_y = cv_y[selected]
        descriptors = descriptors[selected]
        state_labels = state_labels[selected]
        # Loop-invariant: labels of the remaining samples and the distinct states.
        z = state_labels['labels']
        states = z.unique()

    # squeeze=False keeps ``axs`` 2-D even for a single state or max_nfeat == 1,
    # so ``axs[i, j]`` is always valid.
    fig, axs = plt.subplots(n_basins, max_nfeat, figsize=(4 * max_nfeat, 3.5 * n_basins),
                            sharex=True, sharey=True, squeeze=False)

    for i, (state_name, feat_list) in enumerate(relevant_feat.items()):
        for j, feat_array in enumerate(feat_list):
            if j >= max_nfeat:
                break
            feat = feat_array[2]
            importance = feat_array[1]
            ax = axs[i, j]
            ax.hexbin(cv_x, cv_y, C=descriptors[feat], cmap='coolwarm')
            # Escape underscores so LaTeX does not read them as subscripts.
            title_feat = feat.replace('_', r'\_') if __useTeX__ else feat
            ax.set_title(f'[{i}: {state_name}] {title_feat} - {np.round(importance*100)}%')

            if state_labels is not None:
                for b in states:
                    in_state = (z == b)
                    # Unweighted mean position of the state's samples.
                    mx, my = np.mean(cv_x[in_state]), np.mean(cv_y[in_state])
                    ax.scatter(mx, my, color='w', s=300, alpha=0.7)
                    ax.text(mx, my, b, ha="center", va="center", color='k', fontsize='large')
    return fig, axs
def plot_histogram_features(descriptors,states_labels,classes_names,relevant_feat, hist_offset = -0.2, n_bins = 50, ylog = False, axs = None, height=1, width=6, colors=None):
    """Draw, for each state, stacked per-class histograms of its relevant features.

    One Axes per entry of ``relevant_feat``. Each feature gets its own row
    (the first feature of the list ends up in the top row). In each row, one
    histogram is drawn per class in ``classes_names``, computed on the
    selected samples of that class and scaled to a peak height of 1.

    Quadratic-kernel features (``'a || b'``) missing from ``descriptors`` are
    computed as column products and **added to ``descriptors`` in place**.

    Parameters
    ----------
    descriptors : pandas.DataFrame
    states_labels : pandas.DataFrame
        With ``'labels'`` and boolean ``'selection'`` columns.
    classes_names : dict
        Maps class key -> value compared against ``states_labels['labels']``.
        The key selects ``colors[key]`` or, by default, ``f'fessa{6 - key}'``.
    relevant_feat : dict
        Maps state name -> sequence of tuples whose item ``[2]`` is a feature name.
    hist_offset : float
        Vertical offset of each histogram baseline relative to its row index.
    n_bins : int
        Number of histogram bins.
    ylog : bool
        If True, heights are ``log(count)`` divided by the largest log count;
        empty bins are left blank.
    axs : sequence of Axes, optional
        One per entry of ``relevant_feat``. If None, a new figure is created
        with row heights proportional to the number of features, and laid out
        tightly.
    height, width : float
        Per-feature-row height and total width of the new figure (only when
        ``axs`` is None).
    colors : indexable, optional
        Per-class colors, indexed by the keys of ``classes_names``.
    """
    # TODO: move the plotting keywords into a dict.
    features_per_class = [len(feat_list) for feat_list in relevant_feat.values()]

    added_columns = False
    for _state in relevant_feat.values():
        for _feat_tuple in _state:
            feature = _feat_tuple[2]
            if "||" in feature and feature not in descriptors.columns:
                added_columns = True
                # Split on the same token that was tested, tolerating whitespace.
                left, right = (name.strip() for name in feature.split('||'))
                descriptors[feature] = descriptors[left].values * descriptors[right].values
    if added_columns:
        print("Warning: detected quadratic kernel features, added quadratic features to the input dataframe", file=sys.stderr)

    tight = axs is None
    if tight:
        # squeeze=False + [:, 0] always yields a 1-D array of Axes.
        fig, axs = plt.subplots(len(relevant_feat), 1,
                                figsize=(width, sum(features_per_class) * height),
                                gridspec_kw={'height_ratios': features_per_class},
                                squeeze=False)
        axs = axs[:, 0]

    # Selected samples of each class; independent of state and feature.
    in_class = {
        key: (states_labels['labels'] == name) & (states_labels['selection'])
        for key, name in classes_names.items()
    }

    for b, (basin_name, feat_list) in enumerate(relevant_feat.items()):
        # First class key whose name matches (None if absent); shown in the title.
        basin = next((key for key, name in classes_names.items() if name == basin_name), None)
        ax = axs[b]
        feature_labels = []
        for h, feature in enumerate(feat_list[::-1]):
            feature_name = feature[2]
            if __useTeX__:
                # Escape underscores so LaTeX does not read them as subscripts.
                feature_labels.append(feature_name.replace('_', r'\_'))
            else:
                feature_labels.append(feature_name)
            baseline = h + hist_offset
            for i in classes_names.keys():
                x_i = descriptors[in_class[i]][feature_name]
                if len(x_i) == 0:
                    # No selected samples of this class: nothing to draw
                    # (previously 0/0, or a crash in ylog mode).
                    continue
                hist, edges = np.histogram(x_i, bins=n_bins)
                if ylog:
                    # ``np.NaN`` was removed in NumPy 2.0.
                    y = np.full(hist.shape, np.nan)
                    positive = hist > 0
                    log_counts = np.log(hist[positive])
                    peak = log_counts.max()
                    # If every non-empty bin holds one sample, all logs are 0:
                    # draw them at full height instead of dividing 0 by 0.
                    y[positive] = log_counts / peak if peak > 0 else 1.0
                else:
                    y = hist / hist.max()
                color = colors[i] if colors is not None else f'fessa{6-i}'
                ax.plot(edges[:-1], y + baseline, color=color)
                ax.fill_between(edges[:-1], y + baseline, y2=baseline, color=color, alpha=0.5)
            ax.axhline(y=baseline, xmin=0, xmax=1, color='k', linewidth=.2)

        n_feat = len(feature_labels)
        if n_feat:
            ax.set_ylim(hist_offset, n_feat + hist_offset)
        ax.set_ylabel('Feature histograms')
        # Tick at the middle of each row's [baseline, baseline + 1] band
        # (h + 0.3 for the default hist_offset=-0.2).
        ax.set_yticks(np.arange(n_feat) + hist_offset + 0.5)
        ax.set_yticklabels(feature_labels)
        ax.set_title(f'{basin}: {basin_name}')
    if tight:
        fig.tight_layout()
def plot_fes(cv,bandwidth,states_labels=None,logweights=None,kBT=2.5,cv_list=None,states_subset=None,num_samples=100,ax=None,prefix_label="",colors=None):
    """Plot a kernel-density free-energy profile, optionally shading states.

    The profile is ``-kBT * log p`` shifted to a minimum of 0 and then divided
    by ``kBT``, so it is drawn in units of kBT. The plot assumes a single CV:
    it draws ``mesh[0]`` and labels the x axis with the first column.

    Parameters
    ----------
    cv : pandas.DataFrame
        CV values; restricted to ``cv_list`` when that is given.
    bandwidth, logweights :
        Passed positionally to ``gaussian_kde``.
    kBT : float
        Energy scale (cancels in the plotted profile).
    states_labels : pandas.DataFrame, optional
        With ``'labels'`` and ``'selection'`` columns. If given, every state in
        ``states_subset`` (default: all labels, sorted) is shaded over the
        min-max range of its selected samples and labelled at height 5.
    num_samples : int
        Grid points per dimension.
    ax : Axes, optional
        Created when None.
    prefix_label : str
        Prepended to each state label.
    colors : indexable, optional
        Shading colors by position in the label list (default 'fessa{6-i}').
    """
    data = cv if cv_list is None else cv[cv_list]

    kde = gaussian_kde(data.to_numpy(), bandwidth, logweights)
    bounds = [(column.min(), column.max()) for column in kde.dataset.T]
    grids = [np.linspace(lo, hi, num_samples) for lo, hi in bounds]
    mesh = np.meshgrid(*grids)
    positions = np.vstack([g.ravel() for g in mesh]).T
    fes = -kBT * kde.logpdf(positions)
    fes -= fes.min()

    if ax is None:
        _, ax = plt.subplots()
    ax.plot(mesh[0], fes / kBT, color='dimgrey', linewidth=1.5)
    ax.set_xlabel(data.columns.values[0])
    ax.set_ylabel('FES [$k_B$T]')
    x_lo, x_hi = bounds[0]
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(0)

    if states_labels is None:
        return

    if states_subset is None:
        labels = sorted(states_labels['labels'].unique())
    else:
        labels = states_subset
    is_selected = states_labels['selection'] == True  # noqa: E712
    for i, label in enumerate(labels):
        in_state = (states_labels['labels'] == label) & is_selected
        lo = data[in_state].min().values[0]
        hi = data[in_state].max().values[0]
        color = f'fessa{6-i}' if colors is None else colors[i]
        ax.axvspan(lo, hi, alpha=0.5, color=color)
        ax.text((hi + lo) / 2, 5, prefix_label + str(label), fontsize='medium', ha='center')
def plot_fes_2d(colvar, state_labels, selected_cvs, n_iso_fes = 10, ev_iso_labels = 2, save_folder=None, ax = None, xlim=None, ylim=None, label_names = None, label_colors= None, **kde_kwargs):
    """Draw free-energy isolines on the plane of two CVs and mark every state.

    Parameters
    ----------
    colvar : pandas.DataFrame
        CV values; ``selected_cvs[0]`` / ``selected_cvs[1]`` are the x / y axes.
    state_labels : pandas.DataFrame
        With ``'labels'`` and boolean ``'selection'`` columns.
    selected_cvs : sequence of str
        Only the first two entries are used.
    n_iso_fes : int
        Isolines are drawn at 0, 1, ..., n_iso_fes-1 (labelled in kBT).
    ev_iso_labels : int
        Every ``ev_iso_labels``-th isoline is labelled.
    save_folder : str or None
        Currently unused (kept for API compatibility).
    ax : Axes, optional
        Created when None (previously this crashed).
    xlim, ylim : (float, float), optional
        Axis limits, applied only when given. Before, these arguments were
        overwritten and never used.
    label_names : indexable, optional
        Text for state ``b`` is ``label_names[b]``; by default the label itself.
    label_colors : indexable, optional
        Marker color for state ``b`` is ``label_colors[b]``; by default the
        FESSA palette entry ``int(b)`` (cycled).
    **kde_kwargs
        Forwarded to ``gaussian_kde``. If it contains ``'logweights'``, state
        centers are averaged with weights ``exp(logweights)``; otherwise they
        are unweighted means.

    Returns
    -------
    ax
    """
    if ax is None:
        _, ax = plt.subplots()
    label_x = selected_cvs[0]
    label_y = selected_cvs[1]

    # Free-energy isolines from a KDE of the two CVs.
    num_samples = 100
    # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9.
    grey = plt.get_cmap('Greys_r', n_iso_fes)
    iso_colors = [grey((i + 1) / (n_iso_fes + 3)) for i in range(n_iso_fes)]
    KDE = gaussian_kde(colvar[[label_x, label_y]].to_numpy(), **kde_kwargs)
    bounds = [(v.min(), v.max()) for v in KDE.dataset.T]
    mesh = np.meshgrid(*[np.linspace(lo, hi, num_samples) for lo, hi in bounds])
    positions = np.vstack([g.ravel() for g in mesh]).T
    fes = -KDE.logpdf(positions).reshape(num_samples, num_samples)
    fes -= fes.min()
    CS = ax.contour(*mesh, fes, levels=np.linspace(0, n_iso_fes - 1, n_iso_fes), colors=iso_colors)
    ax.clabel(CS, CS.levels[::ev_iso_labels], fmt=lambda v: str(int(v)) + r'$k_{{\rm B}}T$', inline=True, fontsize=8)

    # State markers at the (optionally weighted) mean of their selected samples.
    x = colvar[label_x]
    y = colvar[label_y]
    z = state_labels['labels']
    sel = state_labels['selection']
    logweights = kde_kwargs.get('logweights')
    if logweights is not None:
        logweights = np.asarray(logweights)
    for b in z.unique():
        mask = np.logical_and(sel, z == b)
        if not np.any(mask):
            # No selected samples: the (weighted) mean is undefined.
            continue
        if logweights is None:
            weights = None
        else:
            lw = logweights[np.asarray(mask)]
            # Shift before exponentiating to avoid overflow; np.average
            # normalizes the weights, so the mean is unchanged.
            weights = np.exp(lw - lw.max())
        mx = np.average(x[mask], weights=weights)
        my = np.average(y[mask], weights=weights)
        if label_colors is not None:
            color = label_colors[b]
        else:
            color = paletteFessa[int(b) % len(paletteFessa)]
        # Previously the default text was ``unique()[b]`` (order of appearance),
        # which could show the wrong label; use the label itself.
        text = label_names[b] if label_names is not None else b
        ax.scatter(mx, my, color=color, s=450, alpha=0.5, edgecolors=None)
        ax.text(mx, my, text, ha="center", va="center", color='k', fontsize='large')

    ax.set_xlabel(selected_cvs[0])
    ax.set_ylabel(selected_cvs[1])
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    return ax