"""TODO: Summary
"""
import os
import re
import functools as ft
import itertools as it
import operator
from hashlib import sha1
import numpy as np
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures import as_completed as cf_as_completed
import multiprocessing
import threading
from homog import hrot
import pandas as pd
import numba as nb
try:
# god, I'm so tired of this crap....
from pyrosetta import rosetta as ros
HAVE_PYROSETTA = True
except ImportError:
HAVE_PYROSETTA = False
jit = nb.njit(nogil=True, fastmath=True)
@jit
def expand_array_if_needed(ary, i):
if len(ary) > i:
return ary
newshape = (ary.shape[0] * 2, ) + ary.shape[1:]
new = np.zeros(newshape, dtype=ary.dtype) - ary.dtype.type(1)
new[:len(ary)] = ary
return new
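# Usage sketch (illustrative only; `_example_expand_array` is not part of the
# original module, and assumes the numba-jitted helper above compiles as in the
# original package): the first dimension doubles whenever index i would fall
# off the end, and freshly added slots are filled with -1.
def _example_expand_array():
    buf = np.zeros(4, dtype=np.int64)
    buf = expand_array_if_needed(buf, 7)  # 7 >= len(buf), so the buffer doubles to 8
    assert buf.shape == (8,)
    assert buf[7] == -1                   # new slots are initialized to -1
    assert np.all(buf[:4] == 0)           # original contents are preserved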
class InProcessExecutor:
"""Executor look-alike that runs submitted work immediately in the calling
process, so the same driver code can use either this or a real
concurrent.futures / dask executor.
"""
def __init__(self, *args, **kw):
"""TODO: Summary
Args:
*args: Description
kw: passthru args """
pass
def __enter__(self):
"""TODO: Summary
Returns:
TYPE: Description
"""
return self
def __exit__(self, *args):
"""TODO: Summary
Args:
*args: Description
"""
pass
def submit(self, fn, *args, **kw):
"""TODO: Summary
Args:
fn (TYPE): Description
args: Description
kw: passthru args
Returns:
TYPE: Description
"""
return NonFuture(fn(*args, **kw))
def map(self, func, *iterables):
"""TODO: Summary
Args:
func (TYPE): Description
iterables: Description
Returns:
TYPE: Description
"""
return map(func, *iterables)
# return (NonFuture(func(*args) for args in zip(iterables)))
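# Usage sketch (illustrative only; `_square` and `_example_in_process_executor`
# are throwaway names, not part of the original module): InProcessExecutor
# mimics the concurrent.futures Executor interface, so the same driver code can
# run serially for debugging or in a real pool.
def _square(x):
    return x * x

def _example_in_process_executor():
    with InProcessExecutor() as pool:
        fut = pool.submit(_square, 3)  # runs immediately, returns a NonFuture
        assert fut.result() == 9
        assert list(pool.map(_square, [1, 2, 3])) == [1, 4, 9]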
class NonFuture:
"""Minimal stand-in for a concurrent.futures.Future wrapping an
already-computed result, as returned by InProcessExecutor.submit.
"""
def __init__(self, result):
"""TODO: Summary
Args:
result (TYPE): Description
"""
self._result = result
self._condition = threading.Condition()
self._state = 'FINISHED'
self._waiters = []
def result(self):
"""TODO: Summary
Returns:
TYPE: Description
"""
return self._result
def cpu_count():
"""TODO: Summary
Returns:
TYPE: Description
"""
try:
return int(os.environ['SLURM_CPUS_ON_NODE'])
except (KeyError, ValueError):
return multiprocessing.cpu_count()
def parallel_batch_map(
pool, function, accumulator, batch_size, map_func_args, **kw
):
"""TODO: Summary
Args:
pool (TYPE): Description
function (TYPE): Description
accumulator (TYPE): Description
batch_size (TYPE): Description
map_func_args (TYPE): Description
kw: passthru args
Yields:
TYPE: Description
"""
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
njobs = len(map_func_args[0])
args = list(zip(*map_func_args))
for ibatch in range(0, njobs, batch_size):
beg = ibatch
end = min(njobs, ibatch + batch_size)
batch_args = args[beg:end] # todo, this could be done lazily...
futures = [pool.submit(function, *a) for a in batch_args]
if isinstance(pool, (ProcessPoolExecutor, ThreadPoolExecutor)):
as_completed = cf_as_completed
elif isinstance(pool, InProcessExecutor):
as_completed = lambda x: x
else:
from dask.distributed import as_completed as dd_as_completed
as_completed = dd_as_completed
for _ in accumulator.accumulate(as_completed(futures)):
yield None
accumulator.checkpoint()
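# Sketch of the accumulator protocol that parallel_batch_map relies on
# (inferred from the calls above; `_ListAccumulator` and
# `_example_parallel_batch_map` are illustrative names, not part of the
# original module): accumulate() consumes completed futures and yields once
# per result, checkpoint() is called after every batch.
class _ListAccumulator:
    def __init__(self):
        self.results = []

    def accumulate(self, futures):
        for f in futures:
            self.results.append(f.result())
            yield None

    def checkpoint(self):
        pass  # a real accumulator might flush results to disk here

def _example_parallel_batch_map():
    accum = _ListAccumulator()
    with InProcessExecutor() as pool:
        args = ([1, 2, 3, 4], [10, 20, 30, 40])  # two parallel argument lists
        for _ in parallel_batch_map(pool, operator.add, accum, batch_size=2,
                                    map_func_args=args):
            pass
    assert sorted(accum.results) == [11, 22, 33, 44]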
def parallel_nobatch_map(
pool, function, accumulator, batch_size, map_func_args, **kw
):
"""TODO: Summary
Args:
pool (TYPE): Description
function (TYPE): Description
accumulator (TYPE): Description
batch_size (TYPE): Description
map_func_args (TYPE): Description
kw: passthru args
Yields:
TYPE: Description
"""
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
njobs = len(map_func_args[0])
args = list(zip(*map_func_args))
futures = [pool.submit(function, *a) for a in args]
if isinstance(pool, (ProcessPoolExecutor, ThreadPoolExecutor)):
as_completed = cf_as_completed
elif isinstance(pool, InProcessExecutor):
as_completed = lambda x: x
else:
from dask.distributed import as_completed as dd_as_completed
as_completed = dd_as_completed
for _ in accumulator.accumulate(as_completed(futures)):
yield None
accumulator.checkpoint()
def tqdm_parallel_map(
pool, function, accumulator, map_func_args, batch_size, **kw
):
"""TODO: Summary
Args:
pool (TYPE): Description
function (TYPE): Description
accumulator (TYPE): Description
map_func_args (TYPE): Description
batch_size (TYPE): Description
kw: passthru args """
for _ in tqdm(parallel_batch_map(pool, function, accumulator, batch_size,
map_func_args=map_func_args, **kw),
total=len(map_func_args[0]), **kw):
pass
def numpy_stub_from_rosetta_stub(rosstub):
"""TODO: Summary
Args:
rosstub (TYPE): Description
Returns:
TYPE: Description
"""
npstub = np.zeros((4, 4))
for i in range(3):
npstub[..., i, 3] = rosstub.v[i]
for j in range(3):
npstub[..., i, j] = rosstub.M(i + 1, j + 1)
npstub[..., 3, 3] = 1.0
return npstub
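# Layout check (illustrative only; `_FakeStub` stands in for a rosetta Stub so
# the conversion can be exercised without PyRosetta): the result is a 4x4
# homogeneous transform with the rotation in [:3, :3], the translation in
# [:3, 3], and a final row of (0, 0, 0, 1).
def _example_numpy_stub_layout():
    class _FakeM:
        def __call__(self, i, j):  # 1-based indexing, like rosetta's xyzMatrix
            return 1.0 if i == j else 0.0
    class _FakeStub:
        M = _FakeM()
        v = [1.0, 2.0, 3.0]        # translation component
    s = numpy_stub_from_rosetta_stub(_FakeStub())
    assert np.allclose(s[:3, :3], np.eye(3))
    assert np.allclose(s[:3, 3], [1, 2, 3])
    assert np.allclose(s[3], [0, 0, 0, 1])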
def rosetta_stub_from_numpy_stub(npstub):
"""TODO: Summary
Args:
npstub (TYPE): Description
Returns:
TYPE: Description
"""
rosstub = ros.core.kinematics.Stub()
rosstub.M.xx = npstub[0, 0]
rosstub.M.xy = npstub[0, 1]
rosstub.M.xz = npstub[0, 2]
rosstub.M.yx = npstub[1, 0]
rosstub.M.yy = npstub[1, 1]
rosstub.M.yz = npstub[1, 2]
rosstub.M.zx = npstub[2, 0]
rosstub.M.zy = npstub[2, 1]
rosstub.M.zz = npstub[2, 2]
rosstub.v.x = npstub[0, 3]
rosstub.v.y = npstub[1, 3]
rosstub.v.z = npstub[2, 3]
return rosstub
def get_bb_stubs(pose, which_resi=None):
"""extract rif style stubs from rosetta pose
Args:
pose (TYPE): Description
which_resi (None, optional): Description
Returns:
TYPE: Description
Raises:
ValueError: Description
"""
if which_resi is None:
which_resi = list(range(1, pose.size() + 1))
npstubs, n_ca_c = [], []
for ir in which_resi:
r = pose.residue(ir)
if not r.is_protein():
raise ValueError(
'non-protein residue %s at position %i' % (r.name(), ir)
)
n, ca, c = r.xyz('N'), r.xyz('CA'), r.xyz('C')
ros_stub = ros.core.kinematics.Stub(ca, n, ca, c)
npstubs.append(numpy_stub_from_rosetta_stub(ros_stub))
n_ca_c.append(
np.array([[n.x, n.y, n.z],
[ca.x, ca.y, ca.z],
[c.x, c.y, c.z]])
)
return np.stack(npstubs).astype('f8'), np.stack(n_ca_c).astype('f8')
def get_bb_coords(pose, which_resi=None):
"""extract N, CA, C backbone coordinates from rosetta pose as homogeneous points
Args:
pose (TYPE): Description
which_resi (None, optional): Description
Returns:
TYPE: Description
Raises:
ValueError: Description
"""
if which_resi is None:
which_resi = list(range(1, pose.size() + 1))
n_ca_c = []
for ir in which_resi:
r = pose.residue(ir)
if not r.is_protein():
raise ValueError(
'non-protein residue %s at position %i' % (r.name(), ir)
)
n, ca, c = r.xyz('N'), r.xyz('CA'), r.xyz('C')
n_ca_c.append(
np.array([
[n.x, n.y, n.z, 1],
[ca.x, ca.y, ca.z, 1],
[c.x, c.y, c.z, 1],
])
)
return np.stack(n_ca_c).astype('f8')
def get_chain_bounds(pose):
"""get 0-based, half-open (start, stop) residue index bounds for each chain
Args:
pose (TYPE): Description
Returns:
list: one (start, stop) tuple per chain, in chain order
"""
ch = np.array([pose.chain(i + 1) for i in range(len(pose))])
chains = list()
for i in range(ch[-1]):
chains.append((np.sum(ch <= i), np.sum(ch <= i + 1)))
assert chains[0][0] == 0
assert chains[-1][-1] == len(pose)
return chains
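# Illustration of the bounds convention (no PyRosetta needed; `chain_ids` below
# is a made-up per-residue chain assignment, not data from a pose):
# get_chain_bounds returns 0-based, half-open (start, stop) pairs, one per
# chain, computed exactly like this.
def _example_chain_bounds_convention():
    chain_ids = np.array([1, 1, 1, 2, 2])  # 3-residue chain 1, 2-residue chain 2
    bounds = [(int(np.sum(chain_ids <= i)), int(np.sum(chain_ids <= i + 1)))
              for i in range(chain_ids[-1])]
    assert bounds == [(0, 3), (3, 5)]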
def pose_bounds(pose, lb, ub):
"""TODO: Summary
Args:
pose (TYPE): Description
lb (TYPE): Description
ub (TYPE): Description
Returns:
TYPE: Description
Raises:
ValueError: Description
"""
if ub < 0: ub = len(pose) + 1 + ub
if lb < 1 or ub > len(pose):
raise ValueError(
'lb/ub ' + str(lb) + '/' + str(ub) +
' out of bounds for pose with len ' + str(len(pose))
)
return lb, ub
def subpose(pose, lb, ub=-1):
"""TODO: Summary
Args:
pose (TYPE): Description
lb (TYPE): Description
ub (TYPE, optional): Description
Returns:
TYPE: Description
"""
lb, ub = pose_bounds(pose, lb, ub)
p = ros.core.pose.Pose()
ros.core.pose.append_subpose_to_pose(p, pose, lb, ub)
return p
def worst_CN_connect(p):
"""TODO: Summary
Args:
p (TYPE): Description
Returns:
TYPE: Description
"""
worst = 0
for ir in range(1, len(p)):
if (p.residue(ir).is_protein() and p.residue(ir + 1).is_protein()
and not (ros.core.pose.is_upper_terminus(p, ir)
or ros.core.pose.is_lower_terminus(p, ir + 1))):
dist = p.residue(ir).xyz('C').distance(p.residue(ir + 1).xyz('N'))
worst = max(abs(dist - 1.32), worst)
return worst
def no_overlapping_adjacent_residues(p):
"""TODO: Summary
Args:
p (TYPE): Description
Returns:
TYPE: Description
"""
for ir in range(1, len(p)):
if (p.residue(ir).is_protein() and p.residue(ir + 1).is_protein()):
dist = p.residue(ir).xyz('CA').distance(
p.residue(ir + 1).xyz('CA')
)
if dist < 0.1: return False
return True
def no_overlapping_residues(p):
"""TODO: Summary
Args:
p (TYPE): Description
Returns:
TYPE: Description
"""
for ir in range(1, len(p) + 1):
if not p.residue(ir).is_protein():
continue
for jr in range(1, ir):
if not p.residue(jr).is_protein():
continue
dist = p.residue(ir).xyz('CA').distance(p.residue(jr).xyz('CA'))
if dist < 0.5: return False
return True
def trim_pose(pose, resid, direction, pad=0):
"""trim end of pose from direction, leaving <=pad residues beyond resid
Args:
pose (TYPE): Description
resid (TYPE): Description
direction (TYPE): Description
pad (int, optional): Description
Returns:
TYPE: Description
Raises:
ValueError: Description
"""
if direction not in "NC":
raise ValueError("direction must be 'N' or 'C'")
if not 0 < resid <= len(pose):
raise ValueError("resid %i out of bounds %i" % (resid, len(pose)))
p = ros.core.pose.Pose()
if direction == 'N':
lb, ub = max(resid - pad, 1), len(pose)
elif direction == 'C':
lb, ub = 1, min(resid + pad, len(pose))
# print('_trim_pose lbub', lb, ub, 'len', len(pose), 'resid', resid)
ros.core.pose.append_subpose_to_pose(p, pose, lb, ub)
return p, lb, ub
def symfile_path(name):
"""TODO: Summary
Args:
name (TYPE): Description
Returns:
TYPE: Description
"""
path, _ = os.path.split(__file__)
return os.path.join(path, 'rosetta_symdef', name + '.sym')
@ft.lru_cache()
def get_symfile_contents(name):
"""TODO: Summary
Args:
name (TYPE): Description
Returns:
TYPE: Description
"""
with open(symfile_path(name)) as f:
return f.read()
@ft.lru_cache()
def get_symdata(name):
"""TODO: Summary
Args:
name (TYPE): Description
Returns:
TYPE: Description
"""
if name is None: return None
ss = ros.std.stringstream(get_symfile_contents(name))
d = ros.core.conformation.symmetry.SymmData()
d.read_symmetry_data_from_stream(ss)
return d
def get_symdata_modified(
name, string_substitutions=None, scale_positions=None
):
"""TODO: Summary
Args:
name (TYPE): Description
string_substitutions (None, optional): Description
scale_positions (None, optional): Description
Returns:
TYPE: Description
"""
if name is None: return None
symfilestr = get_symfile_contents(name)
if scale_positions is not None:
if string_substitutions is None:
string_substitutions = dict()
for line in symfilestr.splitlines():
if not line.startswith('xyz'): continue
posstr = re.split(r'\s+', line)[-1]
x, y, z = [float(x) * scale_positions for x in posstr.split(',')]
string_substitutions[posstr] = '%f,%f,%f' % (x, y, z)
if string_substitutions is not None:
for k, v in string_substitutions.items():
symfilestr = symfilestr.replace(k, v)
ss = ros.std.stringstream(symfilestr)
d = ros.core.conformation.symmetry.SymmData()
d.read_symmetry_data_from_stream(ss)
return d
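# Sketch of the scale_positions substitution step in isolation (string handling
# only; the two-line symfile text below is made up for illustration): every
# 'xyz' record's final comma-separated position triple is rescaled and the
# original/new strings are collected for replacement.
def _example_scale_symfile_positions():
    symfilestr = 'symmetry_name fake\nxyz VRT0 1,0,0 0,1,0 0.0,0.0,10.0\n'
    subs = {}
    for line in symfilestr.splitlines():
        if not line.startswith('xyz'):
            continue
        posstr = re.split(r'\s+', line)[-1]
        x, y, z = [float(v) * 2.0 for v in posstr.split(',')]
        subs[posstr] = '%f,%f,%f' % (x, y, z)
    assert subs == {'0.0,0.0,10.0': '%f,%f,%f' % (0.0, 0.0, 20.0)}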
def infer_cyclic_symmetry(pose):
"""TODO: Summary
Args:
pose (TYPE): Description
Raises:
NotImplementedError: Description
"""
raise NotImplementedError
def bigprod(iterable):
"""TODO: Summary
Args:
iterable (TYPE): Description
Returns:
TYPE: Description
"""
return ft.reduce(operator.mul, iterable, 1)
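# Quick sanity check (illustrative only): bigprod multiplies with Python ints,
# so it stays exact where np.prod on fixed-width integers would overflow.
def _example_bigprod():
    assert bigprod([2, 3, 4]) == 24
    assert bigprod([]) == 1                  # empty product is the identity
    assert bigprod([2**40, 2**40]) == 2**80  # exact, no int64 overflow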
class MultiRange:
"""Map a single flat index to a tuple of per-axis indices over the cartesian
product of ranges.
Attributes:
len (int): total number of index tuples
nside (np.ndarray): number of values along each axis
psum (np.ndarray): stride of each axis, used to decode a flat index
"""
def __init__(self, nside):
"""TODO: Summary
Args:
nside (TYPE): Description
"""
self.nside = np.array(nside, dtype='i')
self.psum = np.concatenate([
np.cumprod(self.nside[1:][::-1])[::-1], [1]
])
assert np.all(self.psum > 0)
assert bigprod(self.nside[1:]) < 2**63
self.len = bigprod(self.nside)
def __getitem__(self, idx):
"""TODO: Summary
Args:
idx (TYPE): Description
Returns:
TYPE: Description
Raises:
StopIteration: Description
"""
if isinstance(idx, slice):
return (self[i] for i in range(self.len)[idx])
if idx >= self.len:
raise StopIteration
return tuple((idx // self.psum) % self.nside)
def __len__(self):
"""TODO: Summary
Returns:
TYPE: Description
"""
return self.len
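# Usage sketch (illustrative only): MultiRange enumerates the cartesian product
# of ranges by decoding a single flat index into a tuple of per-axis indices,
# much like unraveling an index over a shape.
def _example_multirange():
    mr = MultiRange([2, 3, 4])
    assert len(mr) == 24
    assert mr[0] == (0, 0, 0)
    assert mr[5] == (0, 1, 1)
    assert mr[23] == (1, 2, 3)
    assert list(mr[0:3]) == [(0, 0, 0), (0, 0, 1), (0, 0, 2)]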
def first_duplicate(segs):
"""TODO: Summary
Args:
segs (TYPE): Description
Returns:
TYPE: Description
"""
for i in range(len(segs) - 1, 0, -1):
for j in range(i):
if segs[i].spliceables == segs[j].spliceables:
return j
return None
def dicts_to_items(inp):
"""TODO: Summary
Args:
inp (TYPE): Description
Returns:
TYPE: Description
"""
if isinstance(inp, list):
return [dicts_to_items(x) for x in inp]
elif isinstance(inp, dict):
return [(k, dicts_to_items(v)) for k, v in inp.items()]
return inp
def items_to_dicts(inp):
"""TODO: Summary
Args:
inp (TYPE): Description
Returns:
TYPE: Description
"""
if (isinstance(inp, list) and isinstance(inp[0], tuple)
and len(inp[0]) == 2):
return {k: items_to_dicts(v) for k, v in inp}
elif isinstance(inp, list):
return [items_to_dicts(x) for x in inp]
return inp
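# Round-trip sketch (illustrative only): dicts_to_items turns nested dicts into
# nested (key, value) item lists, and items_to_dicts undoes the conversion.
def _example_dict_item_roundtrip():
    d = {'a': 1, 'b': {'c': 2, 'd': [3, 4]}}
    items = dicts_to_items(d)
    assert items == [('a', 1), ('b', [('c', 2), ('d', [3, 4])])]
    assert items_to_dicts(items) == d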
def residue_coords(p, ir, n=3):
"""return the first n atom coordinates of residue ir as homogeneous points"""
crd = (p.residue(ir).xyz(i) for i in range(1, n + 1))
return np.stack([[c.x, c.y, c.z, 1] for c in crd])
def residue_sym_err(p, ang, ir, jr, n=1, axis=[0, 0, 1], verbose=0):
mxdist = 0
for i in range(n):
xyz0 = residue_coords(p, ir + i)
xyz1 = residue_coords(p, jr + i)
xyz3 = hrot(axis, ang) @ xyz1.T
xyz4 = hrot(axis, -ang) @ xyz1.T
if verbose:
print(i, xyz0)
print(i, xyz1)
print(i, xyz3.T)
print(i, xyz4.T)
print()
mxdist = max(
mxdist,
min(
np.max(np.sum((xyz0 - xyz3.T)**2, axis=1)),
np.max(np.sum((xyz0 - xyz4.T)**2, axis=1))
)
)
return np.sqrt(mxdist)
def unique_key(a, b=None):
if b is None:
raise NotImplementedError
mi = pd.MultiIndex.from_arrays([a, b]).drop_duplicates()
return mi.get_indexer([a, b])
@jit
def _unique_key_int32s(keys):
# densely renumber keys by order of first appearance
keymap = -np.ones(np.max(keys) + 1, dtype=np.int32)
count = 0
for k in keys:
if keymap[k] < 0:
keymap[k] = count
count += 1
out = np.empty(len(keys), dtype=np.int32)
for i in range(len(keys)):
out[i] = keymap[keys[i]]
return out
def unique_key_int32s(a, b):
if b[0] == -1:
assert np.all(b == -1)
return a
a = a.astype('i8')
b = b.astype('i8')
m = np.max(a) + 1
k = b * m + a
assert np.all(k >= 0)
return _unique_key_int32s(k)
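# Usage sketch (illustrative only): unique_key_int32s maps each distinct (a, b)
# pair to a small dense integer id, numbered by order of first appearance.
def _example_unique_key_int32s():
    a = np.array([0, 1, 0, 2], dtype=np.int32)
    b = np.array([5, 5, 5, 7], dtype=np.int32)
    assert list(unique_key_int32s(a, b)) == [0, 1, 0, 2]
    # when b is all -1 the pairing is skipped and a is returned unchanged
    allneg = -np.ones(4, dtype=np.int32)
    assert list(unique_key_int32s(a, allneg)) == list(a)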
@nb.njit('int32[:](int32[:])', nogil=1)
def contig_idx_breaks(idx):
breaks = np.empty(idx[-1] + 2, dtype=np.int32)
breaks[0] = 0
n = 1
for i in range(1, len(idx)):
if idx[i - 1] != idx[i]:
assert idx[i - 1] < idx[i]
breaks[n] = i
n += 1
breaks[n] = len(idx)
breaks = np.ascontiguousarray(breaks[:n + 1])
if __debug__:
for i in range(breaks.size - 1):
vals = idx[breaks[i]:breaks[i + 1]]
assert len(vals)
assert np.all(vals == vals[0])
return breaks
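# Usage sketch (illustrative only): for a sorted index array, contig_idx_breaks
# returns run boundaries, so idx[breaks[i]:breaks[i + 1]] is the i-th constant
# block.
def _example_contig_idx_breaks():
    idx = np.array([0, 0, 1, 1, 1, 2], dtype=np.int32)
    breaks = contig_idx_breaks(idx)
    assert list(breaks) == [0, 2, 5, 6]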
def hash_str_to_int(s):
if isinstance(s, str): s = s.encode()
buf = sha1(s).digest()[:8]
return int(abs(np.frombuffer(buf, dtype='i8')[0]))
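# Usage sketch (illustrative only): hash_str_to_int gives a deterministic
# non-negative integer for a str or bytes key (unlike the builtin hash(), which
# is randomized per process for strings), handy for reproducible keys.
def _example_hash_str_to_int():
    h1 = hash_str_to_int('hello')
    h2 = hash_str_to_int(b'hello')  # bytes and str hash identically
    assert h1 == h2 >= 0
    assert hash_str_to_int('world') != h1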