Source code for worms.criteria.cyclic

from .base import *
from worms.util import jit
from worms.math import numba_axis_angle_single


[docs]class Cyclic(WormCriteria): def __init__( self, symmetry=1, from_seg=0, *, tol=1.0, origin_seg=None, lever=50.0, to_seg=-1, min_radius=0, ): if from_seg == to_seg: raise ValueError('from_seg should not be same as to_seg') if from_seg == origin_seg: raise ValueError('from_seg should not be same as origin_seg') if to_seg == origin_seg: raise ValueError('to_seg should not be same as origin_seg') if isinstance(symmetry, int): symmetry = 'C' + str(symmetry) self.symmetry = symmetry self.tol = tol self.from_seg = from_seg self.origin_seg = origin_seg self.lever = lever self.to_seg = to_seg self.rot_tol = tol / lever # self.relweight = relweight if abs(relweight) > 0.001 else None if self.symmetry[0] in 'cC': self.nfold = int(self.symmetry[1:]) if self.nfold <= 0: raise ValueError('invalid symmetry: ' + symmetry) self.symangle = np.pi * 2.0 / self.nfold else: raise ValueError('can only do Cx symmetry for now') if self.tol <= 0: raise ValueError('tol should be > 0') self.last_body_same_as = self.from_seg self.is_cyclic = True self.symname = None if self.nfold > 1: self.symname = 'C' + str(self.nfold) self.sym_axes = [(self.nfold, Uz, [0, 0, 0, 1])] a = self.symangle if self.nfold == 1: self.min_sep2 = 0.0 elif self.nfold == 2: self.min_sep2 = 2.0 * min_radius else: self.min_sep2 = min_radius * np.sin(a) / np.sin((np.pi - a) / 2) self.min_sep2 = self.min_sep2**2
[docs] def score(self, segpos, *, verbosity=False, **kw): x_from = segpos[self.from_seg] x_to = segpos[self.to_seg] xhat = x_to @ inv(x_from) trans = xhat[..., :, 3] if self.nfold is 1: angle = hm.angle_of(xhat) carterrsq = np.sum(trans[..., :3]**2, axis=-1) roterrsq = angle**2 else: if self.origin_seg is not None: tgtaxis = segpos[self.origin_seg] @ [0, 0, 1, 0] tgtcen = segpos[self.origin_seg] @ [0, 0, 0, 1] axis, angle, cen = hm.axis_ang_cen_of(xhat) carterrsq = hm.hnorm2(cen - tgtcen) roterrsq = (1 - np.abs(hm.hdot(axis, tgtaxis))) * np.pi else: # much cheaper if cen not needed axis, angle = hm.axis_angle_of(xhat) carterrsq = roterrsq = 0 carterrsq = carterrsq + hm.hdot(trans, axis)**2 roterrsq = roterrsq + (angle - self.symangle)**2 carterrsq[np.sum(trans[..., :3]**2, axis=-1) < self.min_sep2] = 9e9 # if self.relweight is not None: # # penalize 'relative' error # distsq = np.sum(trans[..., :3]**2, axis=-1) # relerrsq = carterrsq / distsq # relerrsq[np.isnan(relerrsq)] = 9e9 # # too much of a hack?? # carterrsq += self.relweight * relerrsq if verbosity > 0: print('axis', axis[0]) print('trans', trans[0]) print('dot trans', hm.hdot(trans, axis)[0]) print('angle', angle[0] * 180 / np.pi) return np.sqrt(carterrsq / self.tol**2 + roterrsq / self.rot_tol**2)
[docs] def alignment(self, segpos, **kw): if self.origin_seg is not None: return inv(segpos[self.origin_seg]) x_from = segpos[self.from_seg] x_to = segpos[self.to_seg] xhat = x_to @ inv(x_from) axis, ang, cen = hm.axis_ang_cen_of(xhat) # print('aln', axis) # print('aln', ang * 180 / np.pi) # print('aln', cen) # print('aln', xhat[..., :, 3]) dotz = hm.hdot(axis, Uz)[..., None] tgtaxis = np.where(dotz > 0, [0, 0, 1, 0], [0, 0, -1, 0]) align = hm.hrot((axis + tgtaxis) / 2, np.pi, cen) align[..., :3, 3] -= cen[..., :3] return align
[docs] def jit_lossfunc(self): tgt_ang = self.symangle from_seg = self.from_seg to_seg = self.to_seg lever = self.lever min_sep2 = self.min_sep2 @jit def func(pos, idx, verts): x_from = pos[from_seg] x_to = pos[to_seg] xhat = x_to @ np.linalg.inv(x_from) if np.sum(xhat[:3, 3]**2) < min_sep2: return 9e9 axis, angle = numba_axis_angle_single(xhat) rot_err_sq = lever**2 * (angle - tgt_ang)**2 cart_err_sq = (np.sum(xhat[:, 3] * axis))**2 return np.sqrt(rot_err_sq + cart_err_sq) return func