import nglview
import matplotlib
import matplotlib.cm as cm
import numpy as np
from scipy.sparse import csr_matrix
from .._configs import *
from .plot import paletteFessa
from time import sleep
def visualize_features(trajectory, states_labels, classes_names, relevant_features, feats_info, state=0, n_feat_per_state=3, representation='licorice'):
    """Visualize snapshots of each state highlighting the relevant features for a given state.

    One frame is drawn at random for every key of ``classes_names`` among the
    rows of ``states_labels`` whose ``'labels'`` column equals that key and
    whose ``'selection'`` column is true. The sampled frames are superposed
    on the first one and shown in a single nglview widget.

    Parameters
    ----------
    trajectory : mdtraj.Trajectory
        MD trajectory
    states_labels : pd.DataFrame
        labels; the ``'labels'`` and ``'selection'`` columns are read
    classes_names : dict
        names of the classes; its keys are matched against the ``'labels'``
        column and ``classes_names[state]`` is the key used to look up
        ``relevant_features``
    relevant_features : dict
        features selected by Lasso, one sequence per class name; element
        ``[2]`` of every entry is used as the feature name
    feats_info : pd.DataFrame
        descriptors information; ``feats_info[name]['atoms']`` must hold the
        atom ids involved in feature ``name``
    state : int, optional
        state for which the features are displayed, by default 0
    n_feat_per_state : int, optional
        number of features to be highlighted, by default 3
    representation : str, optional
        type of representation (licorice, cartoon, ball-and-stick), by
        default 'licorice'. ``'ball_and_stick'`` is accepted as an alias of
        ``'ball-and-stick'``; any other value adds no base representation.

    Returns
    -------
    nglview viewer
        View object
    """
    # Local import: keeps the block self-contained without touching the
    # module-level import section.
    from itertools import cycle

    # sample one frame per state
    frames = []
    for label in classes_names.keys():
        candidates = states_labels[(states_labels['labels'] == label) & (states_labels['selection'])]
        frames.append(candidates.sample(1).index.values[0])
    traj = trajectory[frames]
    traj.superpose(traj[0])

    # find atom ids of the first n_feat_per_state relevant features
    atom_ids = []
    for i, feature in enumerate(relevant_features[classes_names[state]]):
        if i >= n_feat_per_state:
            break
        name = feature[2]
        atom_ids.append(feats_info[name]['atoms'])

    # set up visualization
    view = nglview.show_mdtraj(traj, default=False)

    # representation
    if representation == 'licorice':
        view.add_licorice('(not hydrogen)', opacity=0.35)
        view.add_licorice('(backbone) and (not hydrogen)', opacity=0.85)
    elif representation == 'cartoon':
        view.add_cartoon('protein', opacity=0.85)
    elif representation in ('ball-and-stick', 'ball_and_stick'):
        view.add_ball_and_stick('(not hydrogen)', opacity=0.15)
        view.add_ball_and_stick('(backbone) and (not hydrogen)', opacity=0.5)

    # colors: cycled so that asking for more features than there are colors
    # reuses them instead of exhausting the iterator
    colors = cycle(['orange', 'green', 'purple', 'yellow', 'red'])

    # loop over relevant features
    for ids in atom_ids:
        ids_string = [str(p) for p in ids]
        selection = '@' + ','.join(ids_string)
        # a color is consumed for every feature, drawn or not, so the
        # feature -> color assignment only depends on the feature rank
        color = next(colors)
        if len(ids) == 2:  # distance
            atom_pair = ['@' + p for p in ids_string]
            view.add_distance(atom_pair=[atom_pair], color=color, label_visible=False)
            view.add_ball_and_stick(selection, color=color, opacity=0.75)
        elif len(ids) == 4:  # angle
            view.add_ball_and_stick(selection, color=color, opacity=0.75)

    return view
def compute_residue_score(classifier, reg, feats_info, n_residues):
    """Compute a residue score by aggregating all the features relevances by residues.

    For every set of coefficients stored at the regularization closest to
    ``reg``, the squared coefficients are normalized to sum to one and the
    weight of each non-zero feature is split equally among the residues
    listed in that feature's ``'group'`` entry.

    Parameters
    ----------
    classifier : Classifier
        classifier object; ``_closest_reg_idx``, ``_coeffs``,
        ``_classes_labels``, ``classes``, ``features`` and
        ``_quadratic_kernel`` are read
    reg : float
        regularization magnitude
    feats_info : DataFrame
        descriptors information; ``feats_info[name]['group']`` must be a
        ``'_'``-separated string of residue names, each containing the
        (1-based) residue number as its digits
    n_residues : int
        number of residues

    Returns
    -------
    dictionary
        residue score per each state (array of length ``n_residues``); the
        score is all zeros for a state whose squared coefficients sum to
        less than ``__EPS__``

    Raises
    ------
    NotImplementedError
        if a state has non-zero coefficients and the classifier uses a
        quadratic kernel
    """
    reg_idx = classifier._closest_reg_idx(reg)
    coefficients = classifier._coeffs[reg_idx]
    _classes = classifier._classes_labels[reg_idx]

    residue_score = dict()
    for idx, coef in enumerate(coefficients):
        score = np.zeros(n_residues)
        state_name = classifier.classes[_classes[idx]]

        weights = coef ** 2
        nrm = np.sum(weights)
        # Normalize only once the norm is known to be usable: dividing
        # first would compute 0/0 for a state without selected features.
        if not nrm < __EPS__:
            weights = weights / nrm
            indices = np.flatnonzero(weights)
            if indices.size and classifier._quadratic_kernel:
                raise NotImplementedError("Residue Score not implemented for quadratic features")
            for i in indices:
                feature_name = classifier.features[i]
                resnames = feats_info[feature_name]['group'].split('_')
                for res in resnames:
                    # residue numbers are 1-based in the names, 0-based in score
                    res_idx = int(''.join([n for n in res if n.isdigit()])) - 1
                    score[res_idx] += weights[i] / len(resnames)

        residue_score[state_name] = score

    return residue_score
def visualize_residue_score(trajectory, states_labels, classes_names, residue_score, representation='licorice', palette='Reds', state_frames=None, relevant_features=None, features_info=None):
    """Visualize snapshots of each state coloring the residues with the score per each state.

    The widget shows one frame per state; whenever the displayed frame
    changes, the residues are recolored with the scores of the state at the
    same position in ``classes_names.values()``.

    Parameters
    ----------
    trajectory : mdtraj.Trajectory
        MD trajectory
    states_labels : pd.DataFrame
        labels; the ``'labels'`` and ``'selection'`` columns are read when
        ``state_frames`` is not given
    classes_names : dict
        names of the classes; its values are matched against the
        ``'labels'`` column and used as keys of ``residue_score``
    residue_score : dict
        dictionary with the scores per each state
    representation : str, optional
        type of representation (licorice, cartoon, ball-and-stick), by
        default 'licorice'. Both ``'ball_and_stick'`` and
        ``'ball-and-stick'`` are accepted; any other value adds no base
        representation.
    palette : str, optional
        color scheme (matplotlib colormap name), by default 'Reds'
    state_frames : sequence of int, optional
        frame indices to display, one per state and in the order of
        ``classes_names.values()``. By default one frame per state is
        sampled at random and the chosen frames are printed.
    relevant_features : dict, optional
        features to highlight; only the first entry of the dictionary is
        used and element ``[2]`` of each of its items is the feature name
    features_info : pd.DataFrame, optional
        descriptors information; ``features_info[name]['atoms']`` holds the
        atom ids. Features are highlighted only when both
        ``relevant_features`` and ``features_info`` are given.

    Returns
    -------
    nglview viewer
        View object
    """
    # sample one frame per state
    if state_frames is None:
        frames = [
            states_labels[(states_labels['labels'] == name) & (states_labels['selection'])].sample(1).index.values[0]
            for name in classes_names.values()
        ]
        print('frames:', frames)
    else:
        frames = state_frames
    traj = trajectory[frames]
    traj.superpose(traj[0])

    view = nglview.show_mdtraj(traj, default=False)
    if representation == 'licorice':
        view.add_licorice('(not hydrogen)')
    elif representation == 'cartoon':
        view.add_cartoon('protein')
    elif representation in ('ball_and_stick', 'ball-and-stick'):
        view.add_ball_and_stick('(not hydrogen)')

    # emphasize every residue with a positive score in any state
    for score in residue_score.values():
        for res, rescore in enumerate(score):
            if rescore > 0:
                resnum = str(res + 1)
                view.add_ball_and_stick(f'{resnum} and (not hydrogen)', opacity=0.75)
                view.add_ball_and_stick(f'{resnum}', opacity=0.5)

    # highlight selected features
    if relevant_features is not None and features_info is not None:
        features = relevant_features[next(iter(relevant_features))]
        atom_ids = []
        for feature in features:
            name = feature[2]
            print(name, features_info[name]['atoms'])
            atom_ids.append(features_info[name]['atoms'])

        color = paletteFessa[1]
        # loop over relevant features
        for ids in atom_ids:
            ids_string = [str(p) for p in ids]
            selection = '@' + ','.join(ids_string)
            if len(ids) == 2:  # distance
                atom_pair = ['@' + p for p in ids_string]
                view.add_distance(atom_pair=[atom_pair], color=color, label_visible=False)
            elif len(ids) == 4:  # angle
                view.add_ball_and_stick(selection, color=color, opacity=0.75)

    # get color palette; matplotlib.cm.get_cmap is kept only as a fallback
    # for matplotlib versions without the colormap registry / resampled()
    n_colors = 11
    if hasattr(matplotlib, 'colormaps') and hasattr(matplotlib.colors.Colormap, 'resampled'):
        cmap = matplotlib.colormaps[palette].resampled(n_colors)
    else:
        cmap = matplotlib.cm.get_cmap(palette, n_colors)
    hex_palette = [matplotlib.colors.rgb2hex(cmap(i)) for i in range(cmap.N)]

    # transform score in colors, already in the '0x...' form sent to the widget
    last_bin = cmap.N - 1
    residue_colors = {}
    for i, state in enumerate(classes_names.values()):
        state_colors = []
        for score in residue_score[state]:
            # scores of 0.2 and above saturate at the last color
            col = min(int(score * 5 * 10), last_bin)
            state_colors.append(hex_palette[col].replace('#', '0x'))
        residue_colors[i] = state_colors

    def _apply_residue_colors(colors, component_index=0, repr_index=0):
        # Sent through _remote_call directly instead of an attribute patched
        # onto the view: if the widget already had a bound method with that
        # name, calling it with the view as an extra first argument would
        # shift every parameter by one.
        view._remote_call('setColorByResidue',
                          target='Widget',
                          args=[colors, component_index, repr_index])

    # observer function to allow changing colors with frame
    def on_change(change):
        frame_color = residue_colors.get(change.new)
        if frame_color is None:
            # frame without an associated state: keep the current colors
            return
        _apply_residue_colors(frame_color)
        sleep(0.1)  # wait for the color update

    # set colors from state 0
    _apply_residue_colors(residue_colors[0])
    view.observe(on_change, names=['frame'])

    return view
def visualize_protein_features(trajectory, states_labels, classes_names, residue_score, representation='licorice', state_frames=None, relevant_features=None, features_info=None, all_atoms=False, color=None):
    """Visualize snapshots of each state with the relevant features highlighted.

    Parameters
    ----------
    trajectory : mdtraj.Trajectory
        MD trajectory
    states_labels : pd.DataFrame
        labels; the ``'labels'`` and ``'selection'`` columns are read when
        ``state_frames`` is not given
    classes_names : dict
        names of the classes; its values are matched against the
        ``'labels'`` column
    residue_score : dict
        dictionary with the scores per each state; only read when
        ``all_atoms`` is true
    representation : str, optional
        type of representation (licorice, cartoon, ball_and_stick), by
        default 'licorice'. ``'ball-and-stick'`` is accepted as an alias of
        ``'ball_and_stick'``; any other value adds no base representation.
    state_frames : sequence of int, optional
        frame indices to display. By default one frame per state is sampled
        at random and the chosen frames are printed.
    relevant_features : dict, optional
        features to highlight; only the first entry of the dictionary is
        used and element ``[2]`` of each of its items is the feature name
    features_info : pd.DataFrame, optional
        descriptors information; ``features_info[name]['atoms']`` holds the
        atom ids. Without it no feature is highlighted.
    all_atoms : bool, optional
        if true, show the heavy atoms of every residue with a positive
        score, by default False
    color : optional
        color used for every highlighted feature. By default each feature
        gets the next color of ``paletteFessa`` (cycled).

    Returns
    -------
    nglview viewer
        View object
    """
    # Local import: keeps the block self-contained without touching the
    # module-level import section.
    from itertools import cycle

    # sample one frame per state
    if state_frames is None:
        frames = [
            states_labels[(states_labels['labels'] == name) & (states_labels['selection'])].sample(1).index.values[0]
            for name in classes_names.values()
        ]
        print('frames:', frames)
    else:
        frames = state_frames
    traj = trajectory[frames]
    traj.superpose(traj[0])

    view = nglview.show_mdtraj(traj, default=False)
    if representation == 'licorice':
        view.add_licorice('(not hydrogen)')
    elif representation == 'cartoon':
        view.add_cartoon('protein')
    elif representation in ('ball_and_stick', 'ball-and-stick'):
        view.add_ball_and_stick('(not hydrogen) and (backbone)')

    if all_atoms:
        for score in residue_score.values():
            for res, rescore in enumerate(score):
                if rescore > 0:
                    resnum = str(res + 1)
                    view.add_ball_and_stick(f'{resnum} and (not hydrogen)', opacity=0.75)

    # highlight selected features
    if relevant_features is not None:
        atom_ids = []
        features = relevant_features[next(iter(relevant_features))]
        if features_info is not None:
            for feature in features:
                name = feature[2]
                print(name, features_info[name]['atoms'])
                atom_ids.append(features_info[name]['atoms'])

        palette_colors = cycle(paletteFessa)
        # loop over relevant features
        for ids in atom_ids:
            ids_string = [str(p) for p in ids]
            selection = '@' + ','.join(ids_string)
            # Resolve the color per feature without rebinding the `color`
            # argument, otherwise the first palette color would stick to
            # every following feature.
            feature_color = next(palette_colors) if color is None else color
            if len(ids) == 2:  # distance
                atom_pair = ['@' + p for p in ids_string]
                view.add_distance(atom_pair=[atom_pair], color=feature_color, label_visible=False)
            elif len(ids) == 4:  # angle
                view.add_ball_and_stick(selection, color=feature_color, opacity=0.75)

    return view