Skip to content

Commit

Permalink
Tidying and removing redundant for-loop
Browse files Browse the repository at this point in the history
  • Loading branch information
Glycocalex committed Dec 3, 2024
1 parent 034b6ad commit f7e37ea
Showing 1 changed file with 53 additions and 91 deletions.
144 changes: 53 additions & 91 deletions glycowork/glycan_data/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,47 +263,27 @@ def jtkdist(timepoints: Union[int, np.ndarray], # number/array of timepoints wit
normal: bool = False # flag for normal approximation if max possible negative log p-value too large
) -> Dict: # updated param_dic with statistical values
"Precalculates all possible JT test statistic permutation probabilities for reference using the Harding algorithm"
timepoints = timepoints if isinstance(timepoints, int) else timepoints.sum()
tim = np.full(timepoints, reps) if reps != timepoints else reps # Support for unbalanced replication (unequal replicates in all groups)
if np.max(tim) > 0:
range_array = np.arange(1, np.max(tim)+1)
if len(range_array) > 0:
log_sum = np.sum(np.log(range_array))
maxnlp = gammaln(np.sum(tim)) - log_sum
else:
maxnlp = 0
else:
maxnlp = 0
limit = math.log(float('inf'))
normal = normal or (maxnlp > limit - 1) # Switch to normal approximation if maxnlp is too large
tim = np.full(timepoints, reps)
nn = sum(tim) # Number of data values (Independent of period and lag)
M = (nn ** 2 - np.sum(np.square(tim))) * 0.5 if nn > 0 else 0 # Max possible jtk statistic
param_dic.update({"GRP_SIZE": tim, "NUM_GRPS": len(tim), "NUM_VALS": nn,
"MAX": M, "DIMS": [int(nn * (nn - 1) * 0.5), 1 if nn > 1 else [0, 0]]})
if normal:
if nn > 0:
squared_terms = np.square(tim) * (2 * tim + 3)
var = (nn ** 2 * (2 * nn + 3) - np.sum(squared_terms)) / 72
else:
var = 0
param_dic["VAR"] = var # Variance of JTK
param_dic["SDV"] = np.sqrt(max(var, 0.0)) # Standard deviation of JTK
param_dic["EXV"] = M * 0.5 # Expected value of JTK
param_dic["EXACT"] = False
squared_terms = np.square(tim) * (2 * tim + 3)
var = (nn ** 2 * (2 * nn + 3) - np.sum(squared_terms)) / 72
param_dic["VAR"] = var # Variance of JTK
param_dic["SDV"] = np.sqrt(max(var, 0.0)) # Standard deviation of JTK
param_dic["EXV"] = M * 0.5 # Expected value of JTK
param_dic["EXACT"] = False
MM = int(M // 2) # Mode of this possible alternative to JTK distribution
cf = [1] * (MM + 1) # Initial lower half cumulative frequency (cf) distribution
size = sorted(tim) # Sizes of each group of known replicate values, in ascending order for fastest calculation
k = len(tim) # Number of groups of replicates
if k > 1:
N = [size[k-1]]
for i in range(k - 1, 1, -1): # Count permutations using the Harding algorithm
N.insert(0, (size[i] + N[0]))
else:
N = [size[0]] if k == 1 else []
N = [size[k-1]]
for i in range(k - 1, 1, -1): # Count permutations using the Harding algorithm
N.insert(0, (size[i] + N[0]))
for m, n in zip(size[:-1], N):
update_cf_for_m_n(m, n, MM, cf)
cf = np.array(cf)
# cf now contains the lower half cumulative frequency distribution
cf = np.array(cf) # cf now contains the lower half cumulative frequency distribution
# append the symmetric upper half cumulative frequency distribution to cf
if M % 2: # jtkcf = upper-tail cumulative frequencies for all integer jtk
jtkcf = np.concatenate((cf, 2 * cf[MM] - cf[:MM][::-1], [2 * cf[MM]]))[::-1]
Expand Down Expand Up @@ -335,58 +315,43 @@ def jtkinit(periods: List[int], # possible periods of rhythmicity in biological
time2angle = np.array([(2*round(math.pi, 4)) / period]) # convert time to angle using an ~pi value
theta = timerange * time2angle # zero-based angular values across time indices
cos_v = np.cos(theta) # unique cosine values at each time point
if len(cos_v) > 0:
ranked = rankdata(cos_v)
cos_r = np.repeat(ranked, np.max(tim)) if np.max(tim) > 0 else ranked # replicated ranks of unique cosine values
else:
cos_r = np.array([])
if len(cos_r) > 0:
cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).astype(int)
lower_tri = []
for col in range(len(cgoos)):
for row in range(col + 1, len(cgoos)):
lower_tri.append(cgoos[row, col])
cgoos = np.array(lower_tri)
if len(cgoos) > 0:
cgoosv = np.array(cgoos).reshape(param_dic["DIMS"])
period_array = np.zeros((cgoos.shape[0], period))
period_array[:, 0] = cgoosv[:, 0]
param_dic["CGOOSV"].append(period_array)
else:
param_dic["CGOOSV"].append(np.zeros((1, period)))
else:
param_dic["CGOOSV"].append(np.zeros((1, period)))
ranked = rankdata(cos_v)
cos_r = np.repeat(ranked, np.max(tim)) if np.max(tim) > 0 else ranked # replicated ranks of unique cosine values
cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).astype(int)
lower_tri = []
for col in range(len(cgoos)):
for row in range(col + 1, len(cgoos)):
lower_tri.append(cgoos[row, col])
cgoos = np.array(lower_tri)
cgoosv = np.array(cgoos).reshape(param_dic["DIMS"])
period_array = np.zeros((cgoos.shape[0], period))
period_array[:, 0] = cgoosv[:, 0]
param_dic["CGOOSV"].append(period_array)
cycles = math.floor(timepoints / period)
jrange = np.arange(cycles * period)
if len(cos_v) > 0:
cos_s = np.sign(cos_v)[jrange]
cos_s = np.repeat(cos_s, (tim[jrange]))
if replicates == 1:
param_dic["SIGNCOS"][:len(cos_s), i] = cos_s
else:
param_dic["SIGNCOS"][i, :len(cos_s)] = cos_s
for j in range(1, period): # One-based half-integer lag index j
delta_theta = j * time2angle / 2 # Angles of half-integer lags
cos_v = np.cos(theta + delta_theta) # Cycle left
cos_r = np.concatenate([np.repeat(val, num) for val, num in zip(rankdata(cos_v), tim)]) # Phase-shifted replicated ranks
cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).T
mask = np.triu(np.ones(cgoos.shape), k = 1).astype(bool)
mask[np.diag_indices(mask.shape[0])] = False
cgoos = cgoos[mask]
cgoosv = cgoos.reshape(param_dic["DIMS"])
param_dic["CGOOSV"][i][:, j] = cgoosv.flatten()
cos_v = cos_v.flatten()
cos_s = np.sign(cos_v)[jrange]
if len(tim[jrange]) > 0:
cos_s = np.repeat(cos_s, (tim[jrange]))
cos_s = np.repeat(cos_s, (tim[jrange]))
if replicates == 1:
param_dic["SIGNCOS"][:len(cos_s), i] = cos_s
param_dic["SIGNCOS"][:len(cos_s), j] = cos_s
else:
param_dic["SIGNCOS"][i, :len(cos_s)] = cos_s
for j in range(1, period): # One-based half-integer lag index j
delta_theta = j * time2angle / 2 # Angles of half-integer lags
cos_v = np.cos(theta + delta_theta) # Cycle left
if len(cos_v) > 0:
cos_r = np.concatenate([np.repeat(val, num) for val, num in zip(rankdata(cos_v), tim)]) # Phase-shifted replicated ranks
if len(cos_r) > 0:
cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).T
mask = np.triu(np.ones(cgoos.shape), k = 1).astype(bool)
mask[np.diag_indices(mask.shape[0])] = False
cgoos = cgoos[mask]
if len(cgoos) > 0:
cgoosv = cgoos.reshape(param_dic["DIMS"])
param_dic["CGOOSV"][i][:, j] = cgoosv.flatten()
cos_v = cos_v.flatten()
cos_s = np.sign(cos_v)[jrange]
if len(tim[jrange]) > 0:
cos_s = np.repeat(cos_s, (tim[jrange]))
if replicates == 1:
param_dic["SIGNCOS"][:len(cos_s), j] = cos_s
else:
param_dic["SIGNCOS"][j, :len(cos_s)] = cos_s
param_dic["SIGNCOS"][j, :len(cos_s)] = cos_s
return param_dic


Expand All @@ -399,20 +364,17 @@ def jtkstat(z: pd.DataFrame, # expression data for a molecule ordered in groups
z = np.array(z).flatten()
valid_mask = ~np.isnan(z)
z_valid = z[valid_mask]
if len(z_valid) > 1:
foosv = np.sign(np.subtract.outer(z_valid, z_valid)).T # Due to differences in the triangle indexing of R / Python we need to transpose and select upper triangle rather than the lower triangle
mask = np.triu(np.ones(foosv.shape), k = 1).astype(bool) # Additionally, we need to remove the middle diagonal from the tri index
mask[np.diag_indices(mask.shape[0])] = False
foosv = foosv[mask]
expected_dims = param_dic["DIMS"][0] * param_dic["DIMS"][1]
if len(foosv) == expected_dims:
foosv = foosv.reshape(param_dic["DIMS"])
else:
temp_foosv = np.zeros(expected_dims)
temp_foosv[:len(foosv)] = foosv[:expected_dims]
foosv = temp_foosv.reshape(param_dic["DIMS"])
foosv = np.sign(np.subtract.outer(z_valid, z_valid)).T # Due to differences in the triangle indexing of R / Python we need to transpose and select upper triangle rather than the lower triangle
mask = np.triu(np.ones(foosv.shape), k = 1).astype(bool) # Additionally, we need to remove the middle diagonal from the tri index
mask[np.diag_indices(mask.shape[0])] = False
foosv = foosv[mask]
expected_dims = param_dic["DIMS"][0] * param_dic["DIMS"][1]
if len(foosv) == expected_dims:
foosv = foosv.reshape(param_dic["DIMS"])
else:
foosv = np.zeros(param_dic["DIMS"])
temp_foosv = np.zeros(expected_dims)
temp_foosv[:len(foosv)] = foosv[:expected_dims]
foosv = temp_foosv.reshape(param_dic["DIMS"])
for i in range(param_dic["PERIODS"][0]):
if i < len(param_dic["CGOOSV"][0]):
cgoosv = param_dic["CGOOSV"][0][i]
Expand Down

0 comments on commit f7e37ea

Please sign in to comment.