Source code for stateinterpreter.descriptors

# imports
import numpy as np
import pandas as pd
import mdtraj as md
import itertools
from warnings import warn

from .utils.io import load_dataframe
from ._configs import *
from .utils._compiled_numerics import contact_function

__all__ = ["compute_descriptors", "load_descriptors"]

def compute_descriptors(traj, descriptors=('ca', 'dihedrals', 'hbonds_distances', 'hbonds_contacts')):
    """Compute a table of structural descriptors from a trajectory.

    Available descriptors:

    - ``'ca'``: distances between CA atoms
    - ``'dihedrals'``: phi, psi, chi1 and chi2 angles (with sin/cos)
    - ``'hbonds_distances'``: hydrogen bonds distances
    - ``'hbonds_contacts'``: hydrogen bonds contacts
    - ``'disulfide'``: disulfide bonds dihedrals (with sin/cos)

    Parameters
    ----------
    traj : mdtraj trajectory
        Trajectory handed to the private ``_CA_DISTANCES``, ``_DIHEDRALS``,
        ``_HYDROGEN_BONDS`` and ``_DISULFIDE_DIHEDRALS`` helpers.
    descriptors : str or sequence of str
        Descriptors to compute. By default:
        ['ca', 'dihedrals', 'hbonds_distances', 'hbonds_contacts'].
        A single string is treated as a one-element list. The argument is
        never modified.

    Returns
    -------
    df : pandas.DataFrame
        One column per feature, in the order the descriptors were computed.
        When both hydrogen-bond descriptors are requested, distances are
        always computed before contacts.
    feats_info : dict
        Maps each feature name to the metadata returned by the helpers
        (``'atoms'`` and ``'group'``).

    Raises
    ------
    KeyError
        If a requested descriptor is not one of the supported names.
    ValueError
        If no descriptor is requested.
    """
    allowed = ('ca', 'dihedrals', 'hbonds_distances', 'hbonds_contacts', 'disulfide')

    if isinstance(descriptors, str):
        descriptors = [descriptors]
    # Private copy: the reordering below must not leak into the caller's
    # list (or into the default argument).
    requested = list(descriptors)
    if not requested:
        raise ValueError("No descriptors requested: `descriptors` must contain at least one entry.")

    # Validate everything up front so that a typo does not surface only
    # after minutes of feature computation.
    for d in requested:
        if d not in allowed:
            raise KeyError(f"descriptor: {d} not valid. Only 'ca', 'dihedrals', 'hbonds_distances', 'hbonds_contacts','disulfide' are allowed.")

    # Cache dists for performance: contacts are derived from the very same
    # donor-acceptor distances, so compute them once.
    _cached_dists = ('hbonds_distances' in requested) and ('hbonds_contacts' in requested)
    if _cached_dists:
        dist_idx = requested.index('hbonds_distances')
        cont_idx = requested.index('hbonds_contacts')
        # swap elements so that hb_distances are calculated first and then cached
        if cont_idx < dist_idx:
            requested[dist_idx], requested[cont_idx] = requested[cont_idx], requested[dist_idx]

    _raw_data = []
    _feats = []
    _feats_info = {}
    _dsts = None

    for d in requested:
        if d == 'ca':
            results = [_CA_DISTANCES(traj)]
        elif d == 'dihedrals':
            results = [
                _DIHEDRALS(traj, kind=angle, sincos=True)
                for angle in ('phi', 'psi', 'chi1', 'chi2')
            ]
        elif d == 'hbonds_distances':
            result = _HYDROGEN_BONDS(traj, 'distances')
            if _cached_dists:
                _dsts = np.copy(result[0])
            results = [result]
        elif d == 'hbonds_contacts':
            # _dsts is None unless distances were requested as well.
            results = [_HYDROGEN_BONDS(traj, 'contacts', _cached_dists=_dsts)]
        else:  # 'disulfide' (names were validated above)
            result = _DISULFIDE_DIHEDRALS(traj, sincos=True)
            if __DEV__:
                print(result[1])
                print(result[2])
            results = [result]

        for res, names, descriptors_ids in results:
            # A helper may find nothing (e.g. no disulfide bond); an empty
            # block carries no columns and could break the stacking below.
            if len(names) == 0:
                continue
            _raw_data.append(res)
            _feats.extend(names)
            _feats_info.update(descriptors_ids)

    if _raw_data:
        df = pd.DataFrame(np.hstack(_raw_data), columns=_feats)
    else:
        # Nothing was found for any requested descriptor: keep one row per
        # frame so the result still lines up with the trajectory.
        df = pd.DataFrame(index=range(traj.n_frames))
    if __DEV__:
        print(f"Descriptors: {df.shape}")
    return df, _feats_info
def load_descriptors(descriptors, start = 0, stop = None, stride = 1, **kwargs):
    """Load a descriptors table and return a row-sliced view without time.

    Parameters
    ----------
    descriptors : object
        Forwarded unchanged to ``load_dataframe`` (presumably a path or an
        existing dataframe -- confirm against ``utils.io.load_dataframe``).
    start, stop, stride : int or None
        Positional row slice applied as ``iloc[start:stop:stride]``.
    **kwargs
        Extra keyword arguments forwarded to ``load_dataframe``.

    Returns
    -------
    pandas.DataFrame
        The selected rows, with the ``"time"`` column removed when present.
    """
    frame = load_dataframe(descriptors, **kwargs)
    row_window = slice(start, stop, stride)
    frame = frame.iloc[row_window, :]
    has_time = "time" in frame.columns
    if has_time:
        frame = frame.drop(columns="time")
    if __DEV__:
        print(f"Descriptors: {frame.shape}")
    return frame
# DESCRIPTORS COMPUTATION


def _residue_tag(table, atom_index):
    """Return ``<resName><resSeq>`` for the atom stored at `atom_index` in `table`."""
    return table["resName"][atom_index] + str(table["resSeq"][atom_index])


def _pair_label(top, prefix, i, j):
    """Build the feature name ``'<prefix> <atom_i>[s] -- <atom_j>[s]'``.

    The ``s`` suffix marks atoms flagged as sidechain by the topology.
    """
    atom_i, atom_j = top.atom(i), top.atom(j)
    return "%s %s%s -- %s%s" % (
        prefix,
        atom_i,
        "s" if atom_i.is_sidechain else "",
        atom_j,
        "s" if atom_j.is_sidechain else "",
    )


def _pair_features(top, table, prefix, pairs):
    """Return (names, descriptors_ids) for a list of atom-index pairs.

    Each entry of ``descriptors_ids`` stores the two atom indices and a
    ``'<res_i>_<res_j>'`` group tag.
    """
    names = []
    descriptors_ids = {}
    for (i, j) in pairs:
        name = _pair_label(top, prefix, i, j)
        names.append(name)
        descriptors_ids[name] = {
            'atoms': [i, j],
            'group': _residue_tag(table, i) + "_" + _residue_tag(table, j),
        }
    return names, descriptors_ids


def _CA_DISTANCES(traj):
    """Distances between every pair of atoms selected by ``"name CA"``.

    Returns
    -------
    dist : array
        Output of ``md.compute_distances`` for all CA pairs.
    names : list of str
        Feature names, ``'DIST <atom_i> -- <atom_j>'``.
    descriptors_ids : dict
        Feature name -> ``{'atoms': [i, j], 'group': '<res_i>_<res_j>'}``.
    """
    if __DEV__:
        print("Computing CA distances")
    table, _ = traj.top.to_dataframe()
    sel = traj.top.select("name CA")
    pairs = list(itertools.combinations(sel, 2))
    # reshape keeps the (n_pairs, 2) layout even when there are no pairs
    pair_array = np.array(pairs, dtype=int).reshape(-1, 2)
    dist = md.compute_distances(traj, pair_array)
    names, descriptors_ids = _pair_features(traj.top, table, "DIST", pairs)
    return dist, names, descriptors_ids


def _HYDROGEN_BONDS(traj, kind, _cached_dists = None):
    """Donor-acceptor distances or contacts for candidate hydrogen bonds.

    Donors are O/N atoms listed first in a bond whose second atom is H;
    acceptors are all O/N atoms. Pairs within the same residue (same
    ``RESNAME+RESSEQ`` prefix of the atom name) are discarded.

    Parameters
    ----------
    traj : mdtraj trajectory
    kind : {'distances', 'contacts'}
        Return raw distances, or ``contact_function`` of the distances.
    _cached_dists : array, optional
        Distances previously returned by this function for the same
        trajectory; when given they are reused instead of recomputed.

    Returns
    -------
    values, names, descriptors_ids
        Feature names start with ``'HB_DIST'`` or ``'HB_C'``.

    Raises
    ------
    KeyError
        If `kind` is not supported (checked before any computation).
    """
    if kind not in ('distances', 'contacts'):
        raise KeyError(f'kind="{kind}" not allowed. Valid values: "distances","contacts".')

    if __DEV__:
        print(f"Computing Hydrogen bonds {kind}")
    top = traj.top
    table, _ = top.to_dataframe()

    # find donors (OH or NH), keeping unique indices
    # NOTE(review): only bonds listed as (heavy atom, H) are recognised;
    # confirm the topology never stores them as (H, heavy atom).
    donors = sorted({
        at_i.index
        for at_i, at_j in top.bonds
        if at_i.element.symbol in ("O", "N") and at_j.element.symbol == "H"
    })
    if __DEV__:
        print("Donors:", donors)

    # Find acceptors (O or N)
    acceptors = top.select("symbol O or symbol N")
    if __DEV__:
        print("Acceptors:", acceptors)

    # Residue tag of each involved atom, computed once instead of once per
    # donor-acceptor combination.
    residue_of = {}
    for idx in itertools.chain(donors, acceptors):
        if idx not in residue_of:
            residue_of[idx] = str(top.atom(idx)).split("-")[0]

    # compute pairs, skipping interactions within the same residue and
    # removing duplicates
    pairs = sorted({
        (min(x, y), max(x, y))
        for x in donors
        for y in acceptors
        if (x != y) and (residue_of[x] != residue_of[y])
    })

    # compute distances
    if _cached_dists is None:
        dist = md.compute_distances(traj, pairs)
    else:
        dist = _cached_dists

    if kind == 'distances':
        values, prefix = dist, "HB_DIST"
    else:
        values, prefix = contact_function(dist, r0=0.35, d0=0, n=6, m=12), "HB_C"

    names, descriptors_ids = _pair_features(top, table, prefix, pairs)
    return values, names, descriptors_ids


def _DIHEDRALS(traj, kind, sincos=True):
    """Backbone / sidechain dihedral angles of one kind.

    Parameters
    ----------
    traj : mdtraj trajectory
    kind : {'phi', 'psi', 'chi1', 'chi2'}
    sincos : bool
        When True, sine and cosine of every angle are appended after the
        raw angles (columns ordered: angles, sines, cosines).

    Returns
    -------
    angles, names, descriptors_ids
        Names are ``'<kind> <res>'``, ``'sin_<kind> <res>'`` and
        ``'cos_<kind> <res>'``; ``<res>`` is the residue of the first atom
        of each dihedral.

    Raises
    ------
    KeyError
        If `kind` is not supported (checked before any computation).
    """
    if kind not in ("phi", "psi", "chi1", "chi2"):
        raise KeyError(f'kind="{kind}" not allowed. Supported values: "phi", "psi", "chi1", "chi2"')

    compute = {
        "phi": md.compute_phi,
        "psi": md.compute_psi,
        "chi1": md.compute_chi1,
        "chi2": md.compute_chi2,
    }[kind]

    # Get topology
    table, _ = traj.top.to_dataframe()
    dih_idxs, angles = compute(traj)

    names = []
    sin_names = []
    cos_names = []
    descriptors_ids = {}

    for idx in dih_idxs:
        # find residue id from topology table
        res = _residue_tag(table, idx[0])
        name = kind + " " + res
        info = {
            'atoms': list(idx),
            'group': res
        }
        names.append(name)
        descriptors_ids[name] = info
        if sincos:
            sin_name = "sin_" + name
            cos_name = "cos_" + name
            sin_names.append(sin_name)
            cos_names.append(cos_name)
            descriptors_ids[sin_name] = info
            descriptors_ids[cos_name] = info

    if sincos:
        angles = np.hstack([angles, np.sin(angles), np.cos(angles)])
        names = names + sin_names + cos_names

    return angles, names, descriptors_ids


def _DISULFIDE_DIHEDRALS(traj, sincos = True):
    """C-S-S-C dihedrals of the disulfide bonds formed in the first frame.

    Two CYS sulfur atoms are considered bonded when their distance in
    ``traj[0]`` is below 0.25 (units as returned by
    ``md.compute_distances``).

    Returns
    -------
    angles : array, shape (n_frames, n_features)
        Zero columns when no disulfide bond is found.
    names : list of str
        ``'DISULFIDE dih <group>'`` plus, if `sincos`, the ``sin_dih`` and
        ``cos_dih`` variants.
    descriptors_ids : dict
        Feature name -> ``{'atoms': [c_i, i, j, c_j], 'group': group}``.
    """
    table, bonds = traj.top.to_dataframe()

    # filter S atoms belonging to CYS
    s_cys = table[(table['element'] == 'S') & (table['resName'] == 'CYS')].index
    sulfurs = {int(s) for s in s_cys}

    # Carbon bonded to each CYS sulfur. Restricting to carbons prevents the
    # partner sulfur (or a hydrogen) from being picked when the topology
    # lists those bonds too. Only the first two columns are read, so the
    # number of extra bond columns does not matter.
    carbon_of = {}
    for bond in bonds:
        first, second = int(bond[0]), int(bond[1])
        for sulfur, other in ((first, second), (second, first)):
            if sulfur in sulfurs and table['element'][other] == 'C':
                carbon_of.setdefault(sulfur, other)

    # define arrays
    names = []
    angles = []
    descriptors_ids = {}

    # Every pair of S atoms; first-frame distances computed in one call
    sulfur_pairs = list(itertools.combinations(s_cys, 2))
    if sulfur_pairs:
        first_frame = md.compute_distances(traj[0], sulfur_pairs)[0]
    else:
        first_frame = []

    for (i, j), d_ij in zip(sulfur_pairs, first_frame):
        # Check if bond is formed
        if not d_ij < 0.25:
            continue
        c_i = carbon_of.get(int(i))
        c_j = carbon_of.get(int(j))
        if c_i is None or c_j is None:
            warn(f"Disulfide pair ({i}, {j}) skipped: no carbon bonded to the sulfur found in the topology.")
            continue

        # compute feature
        group = str(traj.top.atom(i).residue) + '_' + str(traj.top.atom(j).residue)
        desc = md.compute_dihedrals(traj, [[c_i, i, j, c_j]])[:, 0]

        columns = [('DISULFIDE dih ' + group, desc)]
        if sincos:
            columns.append(('DISULFIDE sin_dih ' + group, np.sin(desc)))
            columns.append(('DISULFIDE cos_dih ' + group, np.cos(desc)))

        for name, values in columns:
            descriptors_ids[name] = {'atoms': [c_i, i, j, c_j], 'group': group}
            angles.append(values)
            names.append(name)

    if not angles:
        # keep a 2-D (n_frames, 0) layout so callers can still stack it
        return np.empty((traj.n_frames, 0)), names, descriptors_ids

    angles = np.asarray(angles).T
    return angles, names, descriptors_ids