Source code for worms.edge

import sys
from time import time
import random
import numpy as np
import numba as nb
import numba.types as nt
from collections import defaultdict, namedtuple
from worms.util import contig_idx_breaks, jit, InProcessExecutor, NonFuture
import concurrent.futures as cf
from tqdm import tqdm
from worms.database import SpliceDB

try:
    # pyrosetta is treated as optional: if either import fails, the module
    # still loads and HAVE_PYROSETTA records that it is unavailable.
    from pyrosetta import pose_from_file
    from pyrosetta.rosetta.core.scoring.dssp import Dssp
    HAVE_PYROSETTA = True
except ImportError:
    HAVE_PYROSETTA = False


def Edge(u, ublks, v, vblks, verbosity=0, **kw):
    """Build the ``_Edge`` describing the allowed splices from ``u`` to ``v``.

    The per-exit lists returned by ``get_allowed_splices`` are packed into a
    rectangular int32 array, one row per exit of ``u``:

    * column 0 holds ``number_of_entries + 1`` (the slice end used by
      ``_Edge.allowed_entries``),
    * columns ``1 .. number_of_entries`` hold the sorted entry indices,
    * any remaining columns are padded with ``-1``.

    Args:
        u, ublks: upstream vertex and the building blocks it indexes.
        v, vblks: downstream vertex and the building blocks it indexes.
        verbosity: forwarded to ``get_allowed_splices``.
        **kw: forwarded unchanged to ``get_allowed_splices``.

    Returns:
        An ``_Edge`` built from the packed array and the exit/entry counts
        reported by ``get_allowed_splices``.

    Raises:
        AssertionError: if an entry index is not smaller than
            ``len(v.inbreaks)`` or the number of rows does not equal
            ``1 + max(u.inout[:, 1])``.
    """
    splices, nout, nent = get_allowed_splices(
        u, ublks, v, vblks, verbosity=verbosity, **kw
    )

    # default=0 keeps max() from raising when there are no exits at all.
    maxentries = max((len(entries) for entries in splices), default=0)
    splice_ary = np.full(
        (len(splices), maxentries + 1), -1, dtype=np.int32
    )
    for row, entries in zip(splice_ary, splices):
        row[0] = len(entries) + 1
        row[1:row[0]] = sorted(entries)

    # When no splice is valid anywhere the entry columns are zero-size and
    # np.max would raise an unrelated ValueError; there is nothing to check.
    entry_cols = splice_ary[:, 1:]
    assert entry_cols.size == 0 or np.max(entry_cols) < len(v.inbreaks), \
        'edge.py bad splice_ary'
    assert len(splice_ary) == 1 + np.max(u.inout[:, 1]), \
        'edge.py, bad splice_ary'

    return _Edge(splice_ary, nout, nent)
def splice_metrics_pair(
        blk0,
        blk1,
        max_splice_rms=0.7,
        clashd2=3.0**2,
        contactd2=10.0**2,
        rms_range=9,
        clash_contact_range=9,
        skip_on_fail=True,
):
    """Compute splice metrics for one pair of building blocks.

    Convenience wrapper that pulls ``chains``, ``ncac``, ``stubs`` and
    ``connections`` off ``blk0`` and ``blk1`` and passes them, followed by
    the thresholds, positionally to ``_jit_splice_metrics``.

    Returns:
        Whatever ``_jit_splice_metrics`` returns for this pair.
    """
    chains = (blk0.chains, blk1.chains)
    coords = (blk0.ncac, blk1.ncac)
    stubs = (blk0.stubs, blk1.stubs)
    conns = (blk0.connections, blk1.connections)
    # Order must match the positional signature of _jit_splice_metrics.
    settings = (
        clashd2,
        contactd2,
        rms_range,
        clash_contact_range,
        max_splice_rms,
        skip_on_fail,
    )
    return _jit_splice_metrics(*chains, *coords, *stubs, *conns, *settings)
def _group_ires_by_block(blocks, residues):
    """Group ``residues`` by the matching entry of ``blocks`` as int32 arrays."""
    grouped = defaultdict(list)
    for iblk, ires in zip(blocks, residues):
        grouped[iblk].append(ires)
    return {iblk: np.array(ires, 'i4') for iblk, ires in grouped.items()}


def get_allowed_splices(
        u,
        ublks,
        v,
        vblks,
        splicedb=None,
        max_splice_rms=0.7,
        ncontact_cut=10,
        clashd2=3.0**2,
        contactd2=10.0**2,
        rms_range=5,
        clash_contact_range=9,
        skip_on_fail=True,
        parallel=False,
        verbosity=1,
        sync_to_disk_every=0.001,
):
    """Find which exits of vertex ``u`` may be spliced to which entries of ``v``.

    Every (exit block, entry block) pair is scored with
    ``_jit_splice_metrics`` unless ``splicedb`` already holds a result for
    that pair and parameter set. A residue pair is accepted when it has no
    clashes, ``rms <= max_splice_rms`` and ``ncontact >= ncontact_cut``.

    Args:
        u, ublks: upstream vertex and its building blocks.
        v, vblks: downstream vertex and its building blocks.
        splicedb: optional cache; ``partial``, ``add`` and ``sync_to_disk``
            are called on it when it is truthy.
        max_splice_rms, ncontact_cut: acceptance thresholds (see above).
        clashd2, contactd2, rms_range, clash_contact_range, skip_on_fail:
            forwarded to ``_jit_splice_metrics``.
        parallel: use ``ProcessPoolExecutor`` instead of
            ``InProcessExecutor``.
        verbosity: ``> 0`` reports slow cache reads, ``> 1`` shows a tqdm
            progress bar.
        sync_to_disk_every: after each newly computed pair the cache is
            synced when a uniform random draw falls below this value; a
            final sync is done at the end when it is ``> 0``.

    Returns:
        ``(valid_splices, nout, nent)`` where ``valid_splices[i]`` lists the
        entry indices allowed for exit ``i``, and ``nout`` / ``nent`` are the
        numbers of exit and entry residues considered.

    Raises:
        AssertionError: if the exit direction of ``u`` and the entry
            direction of ``v`` do not sum to 1, or the entry residues of a
            block are not sorted.
    """
    assert (u.dirn[1] + v.dirn[0]) == 1, 'get_allowed_splices dirn mismatch'

    # Key under which results for this parameter set are cached.
    params = (
        max_splice_rms, ncontact_cut, clashd2, contactd2, rms_range,
        clash_contact_range, u.min_seg_len, v.min_seg_len
    )

    outidx = _get_outidx(u.inout[:, 1])
    outblk = u.ibblock[outidx]
    outres = u.ires[outidx, 1]

    inblk = v.ibblock[v.inbreaks[:-1]]
    inres = v.ires[v.inbreaks[:-1], 0]
    # Return value is unused; the call is kept in case the helper validates
    # its input (not visible from here).
    contig_idx_breaks(inblk)

    outblk_res = _group_ires_by_block(outblk, outres)
    inblk_res = _group_ires_by_block(inblk, inres)
    for ires in inblk_res.values():
        assert np.all(sorted(ires) == ires)

    nout = sum(len(a) for a in outblk_res.values())
    nent = sum(len(a) for a in inblk_res.values())
    valid_splices = [[] for _ in range(nout)]

    # Scoring is always done N-to-C, so process the pair in swapped roles
    # when u exits through direction 0 and swap the results back below.
    swapped = u.dirn[1] == 0
    if swapped:
        ublks, vblks = vblks, ublks
        outblk_res, inblk_res = inblk_res, outblk_res

    pairs_with_no_valid_splices = 0
    tcache = 0

    exe = cf.ProcessPoolExecutor if parallel else InProcessExecutor
    with exe() as pool:
        futures = []
        ofst0 = 0
        for iblk0, ires0 in outblk_res.items():
            blk0 = ublks[iblk0]
            key0 = blk0.filehash
            t = time()
            cache = splicedb.partial(params, key0) if splicedb else None
            tcache += time() - t
            ofst1 = 0
            for iblk1, ires1 in inblk_res.items():
                blk1 = vblks[iblk1]
                key1 = blk1.filehash
                if cache and key1 in cache and cache[key1]:
                    # Cached result: wrap it so it looks like a future.
                    future = NonFuture(cache[key1])
                else:
                    future = pool.submit(
                        _jit_splice_metrics, blk0.chains, blk1.chains,
                        blk0.ncac, blk1.ncac, blk0.stubs, blk1.stubs,
                        blk0.connections, blk1.connections, clashd2, contactd2,
                        rms_range, clash_contact_range, max_splice_rms,
                        skip_on_fail
                    )
                # Remember which pair and which index offsets this job
                # belongs to; results arrive in completion order.
                future.stash = (iblk0, iblk1, ofst0, ofst1, ires0, ires1)
                futures.append(future)
                ofst1 += len(ires1)
            ofst0 += len(ires0)

        if verbosity > 0 and tcache > 1.0:
            print('get_allowed_splices read caches time:', tcache)

        future_iter = cf.as_completed(futures)
        if verbosity > 1:
            future_iter = tqdm(
                future_iter,
                'checking splices',
                total=len(futures),
                smoothing=0.0  # does this do anything?
            )

        for future in future_iter:
            iblk0, iblk1, ofst0, ofst1, ires0, ires1 = future.stash
            result = future.result()

            # A 3-tuple of arrays is a freshly computed (rms, nclash,
            # ncontact) result; anything else is taken to be an already
            # reduced residue-pair result from the cache.
            if len(result) == 3 and isinstance(result[0], np.ndarray):
                rms, nclash, ncontact = result
                ok = ((nclash == 0) * (rms <= max_splice_rms) *
                      (ncontact >= ncontact_cut))
                result = _splice_respairs(ok, ublks[iblk0], vblks[iblk1])
                if splicedb:
                    key0 = ublks[iblk0].filehash  # C-term side
                    key1 = vblks[iblk1].filehash  # N-term side
                    splicedb.add(params, key0, key1, result)
                    if np.random.random() < sync_to_disk_every:
                        print('sync_to_disk splices data')
                        splicedb.sync_to_disk()

            if swapped:
                # Undo the role swap so indices refer to u exits / v entries.
                result = result[1], result[0]
                ires0, ires1 = ires1, ires0
                ofst0, ofst1 = ofst1, ofst0

            if len(result[0]) == 0:
                pairs_with_no_valid_splices += 1
                continue

            # Translate residue numbers into positions within this block's
            # residue list, then shift by the block's offset.
            index_of_ires0 = _index_of_map(ires0, np.max(result[0]))
            index_of_ires1 = _index_of_map(ires1, np.max(result[1]))
            irs = index_of_ires0[result[0]] + ofst0
            jrs = index_of_ires1[result[1]] + ofst1
            for ir, jr in zip(irs, jrs):
                valid_splices[ir].append(jr)

    if sync_to_disk_every > 0 and splicedb:
        splicedb.sync_to_disk()

    if pairs_with_no_valid_splices:
        print(
            'pairs with no valid splices: ', pairs_with_no_valid_splices, 'of',
            len(outblk_res) * len(inblk_res)
        )

    return valid_splices, nout, nent
@nb.jitclass((
    ('splices', nt.int32[:, :]),
    ('nout'   , nt.int32),
    ('nent'   , nt.int32),
))  # yapf: disable
class _Edge:
    """Allowed splices between the exits of one vertex and entries of the next.

    ``splices`` is a 2D int32 array with one row per exit. Column 0 of a row
    is the slice end for that row, so the allowed entry indices of exit ``i``
    are ``splices[i, 1:splices[i, 0]]``.
    """

    def __init__(self, splices, nout, nent):
        self.splices = splices
        self.nout = nout
        self.nent = nent

    @property
    def len(self):
        """Number of rows in ``splices``."""
        return len(self.splices)

    def allowed_entries(self, i):
        """Return the entry indices stored in row ``i``."""
        assert i >= 0, 'edge.py allowed_entries bad i'
        assert self.splices.shape[0] > i, 'edge.py allowed_entries bad i'
        assert self.splices.shape[1] >= self.splices[i, 0], \
            'edge.py allowed_entries bad i'
        return self.splices[i, 1:self.splices[i, 0]]

    def total_allowed_splices(self):
        """Return the sum of column 0 minus the number of rows."""
        return np.sum(self.splices[:, 0]) - len(self.splices)

    @property
    def _state(self):
        """The constructor arguments as a tuple."""
        return (self.splices, self.nout, self.nent)


@jit
def _chainbounds_of_ires(chains, ires):
    """Return the ``(start, end)`` row of ``chains`` with start <= ires < end.

    Returns ``(-1, -1)`` when no row contains ``ires``.
    """
    for c in range(len(chains)):
        if chains[c, 0] <= ires < chains[c, 1]:
            return chains[c, 0], chains[c, 1]
    return (-1, -1)


@jit
def _ires_from_conn(conn, dirn):
    """Concatenate the residues of every row of ``conn`` whose column 0 is ``dirn``.

    For a row, column 1 is the end of the used part of that row and the
    residues are ``conn[i, 2:conn[i, 1]]``.
    """
    # First pass: count so the output can be allocated in one go.
    n = 0
    for i in range(len(conn)):
        if conn[i, 0] == dirn:
            n += conn[i, 1] - 2
    ires = np.empty(n, dtype=np.int32)
    pos = 0
    for i in range(len(conn)):
        if conn[i, 0] == dirn:
            ires[pos:pos + conn[i, 1] - 2] = conn[i, 2:conn[i, 1]]
            pos += conn[i, 1] - 2
    assert pos == n
    return ires


@jit
def _index_of_map(ary, mx):
    """Return an array mapping each value ``v <= mx`` of ``ary`` to its index.

    The result has length ``mx + 1``; values that do not occur in ``ary`` map
    to ``-1``. If a value occurs more than once, the last index is kept.
    """
    index_of = -np.ones(mx + 1, dtype=np.int32)
    for i, v in enumerate(ary):
        if v <= mx:
            index_of[v] = i
    return index_of


@jit
def _jit_splice_metrics(chains0, chains1, ncac0_3d, ncac1_3d, stubs0, stubs1,
                        conn0, conn1, clashd2=3.0**2, contactd2=10.0**2,
                        rms_range=9, clash_contact_range=9, max_splice_rms=1.1,
                        skip_on_fail=True):  # yapf: disable
    """Score every candidate splice between two building blocks.

    Candidate residues are ``_ires_from_conn(conn0, 1)`` for block 0 and
    ``_ires_from_conn(conn1, 0)`` for block 1. For each pair, block 1 is
    mapped into the frame of block 0 with ``stubs0[aln0] @ inv(stubs1[aln1])``
    and three quantities are computed:

    * rms over the coordinate rows of residues ``aln - rms_range`` through
      ``aln + rms_range`` (three rows per residue),
    * the number of row pairs closer than ``clashd2`` (squared distance)
      between the ``clash_contact_range`` residues before ``aln0`` in block 0
      and the ``clash_contact_range`` residues after ``aln1`` in block 1,
    * the number of non-clashing pairs of middle rows (index % 3 == 1,
      presumably CA) closer than ``contactd2``.

    Residues within ``rms_range`` of either end of their chain are skipped;
    skipped pairs keep rms ``0`` and clash/contact counts ``-1``. When
    ``skip_on_fail`` is true, pairs with ``rms > max_splice_rms`` keep the
    computed rms but also keep counts of ``-1``.

    Returns:
        ``(rms, nclash, ncontact)``, float32 arrays indexed
        ``[block-0 candidate, block-1 candidate]``.
    """
    aln0s = _ires_from_conn(conn0, 1)
    aln1s = _ires_from_conn(conn1, 0)
    out_rms = np.zeros((len(aln0s), len(aln1s)), dtype=np.float32)
    out_nclash = -np.ones((len(aln0s), len(aln1s)), dtype=np.float32)
    out_ncontact = -np.ones((len(aln0s), len(aln1s)), dtype=np.float32)
    # Flatten to one 4-vector per row; rows are addressed as 3 * ires + k.
    ncac0 = ncac0_3d.reshape(-1, 4)
    ncac1 = ncac1_3d.reshape(-1, 4)
    b = np.empty((4, ), dtype=np.float64)

    for ialn1, aln1 in enumerate(aln1s):
        chainb10, chainb11 = _chainbounds_of_ires(chains1, aln1)
        if np.abs(chainb10 - aln1) < rms_range:
            continue
        if np.abs(chainb11 - aln1) <= rms_range:
            continue
        stub1_inv = np.linalg.inv(stubs1[aln1])
        for ialn0, aln0 in enumerate(aln0s):
            chainb00, chainb01 = _chainbounds_of_ires(chains0, aln0)
            if np.abs(chainb00 - aln0) < rms_range:
                continue
            if np.abs(chainb01 - aln0) <= rms_range:
                continue
            xaln = stubs0[aln0] @ stub1_inv

            sum_d2 = 0.0
            for i in range(-3 * rms_range, 3 * rms_range + 3):
                a = ncac0[3 * aln0 + i]
                b[:] = xaln @ ncac1[3 * aln1 + i]
                sum_d2 += np.sum((a - b)**2)
            rms = np.sqrt(sum_d2 / (rms_range * 6 + 3))
            assert 0 <= rms < 9e9, 'bad rms'
            out_rms[ialn0, ialn1] = rms
            if skip_on_fail and rms > max_splice_rms:
                continue

            # NOTE(review): the chain-end margin above is checked against
            # rms_range only; if clash_contact_range > rms_range this window
            # can extend past the chain bounds. Confirm this is intended.
            nclash, ncontact = 0, 0
            for j in range(3, 3 * clash_contact_range + 3):
                b[:] = xaln @ ncac1[3 * aln1 + j]
                for i in range(-1, -3 * clash_contact_range - 1, -1):
                    a = ncac0[3 * aln0 + i]
                    d2 = np.sum((a - b)**2)
                    if d2 < clashd2:
                        nclash += 1
                    elif i % 3 == 1 and j % 3 == 1 and d2 < contactd2:
                        ncontact += 1
            # Check the counters themselves; the previous form tested the
            # boolean result of np.isnan() and therefore could never fail.
            assert 0 <= nclash < 99999, 'bad nclash'
            assert 0 <= ncontact < 99999, 'bad ncontact'
            out_nclash[ialn0, ialn1] = nclash
            out_ncontact[ialn0, ialn1] = ncontact
    return out_rms, out_nclash, out_ncontact


@jit
def _check_inorder(ires):
    """Return True when ``ires`` is in non-decreasing order."""
    for i in range(len(ires) - 1):
        if ires[i] > ires[i + 1]:
            return False
    return True


@jit
def _get_outidx(iout):
    """Invert ``iout``: return ``outidx`` such that ``outidx[iout[i]] == i``.

    The result has length ``max(iout) + 1``. Slots whose index never occurs
    in ``iout`` are left uninitialised (``np.empty``); for repeated values
    the last position wins.
    """
    outidx = np.empty(np.max(iout) + 1, dtype=np.int32)
    for i, o in enumerate(iout):
        outidx[o] = i
    return outidx


@jit
def _splice_respairs(edgemat, bbc, bbn):
    """List the residue pairs flagged in ``edgemat``.

    Rows of ``edgemat`` correspond to ``_ires_from_conn(bbc.connections, 1)``
    and columns to ``_ires_from_conn(bbn.connections, 0)``.

    Returns:
        ``(out0, out1)``: int32 arrays of equal length holding, in row-major
        order, the ``bbc`` and ``bbn`` residue of every truthy element.
    """
    n = np.sum(edgemat)
    out0 = np.empty(n, dtype=np.int32)
    out1 = np.empty(n, dtype=np.int32)
    res0 = _ires_from_conn(bbc.connections, 1)
    res1 = _ires_from_conn(bbn.connections, 0)
    assert len(res0) == edgemat.shape[0]
    assert len(res1) == edgemat.shape[1]
    count = 0
    for i in range(edgemat.shape[0]):
        for j in range(edgemat.shape[1]):
            if edgemat[i, j]:
                out0[count] = res0[i]
                out1[count] = res1[j]
                count += 1
    assert count == n
    return out0, out1