Skip to content

Commit

Permalink
Linting.
Browse files Browse the repository at this point in the history
  • Loading branch information
hmcezar committed Dec 19, 2024
1 parent 67b5e68 commit a8806a6
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 44 deletions.
4 changes: 3 additions & 1 deletion clusttraj/distmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,9 @@ def compute_distmat_line(
# whereins = np.where(
# np.isin(np.arange(natoms), reorderexcl[soluexcl]) is True
# )
whereins = np.where(np.atleast_1d(np.isin(np.arange(natoms), reorderexcl[soluexcl])))
whereins = np.where(

Check warning on line 198 in clusttraj/distmat.py

View check run for this annotation

Codecov / codecov/patch

clusttraj/distmat.py#L198

Added line #L198 was not covered by tests
np.atleast_1d(np.isin(np.arange(natoms), reorderexcl[soluexcl]))
)
Psolu = np.insert(
Pview,
[x - whereins[0].tolist().index(x) for x in whereins[0]],
Expand Down
32 changes: 20 additions & 12 deletions clusttraj/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,11 @@ def parse_args(args: argparse.Namespace) -> ClustOptions:

options_dict = {
"solute_natoms": args.natoms_solute,
"reorder_excl": np.asarray([x - 1 for x in args.reorder_exclusions], np.int32)
if args.reorder_exclusions
else np.asarray([], np.int32),
"reorder_excl": (
np.asarray([x - 1 for x in args.reorder_exclusions], np.int32)
if args.reorder_exclusions
else np.asarray([], np.int32)
),
"exclusions": bool(args.reorder_exclusions),
"reorder_alg_name": args.reorder_alg,
"reorder_alg": None,
Expand All @@ -556,12 +558,12 @@ def parse_args(args: argparse.Namespace) -> ClustOptions:
"out_clust_name": args.outputclusters,
"summary_name": basenameout + ".out",
"save_confs": bool(args.clusters_configurations),
"out_conf_name": basenameout + "_confs"
if args.clusters_configurations
else None,
"out_conf_fmt": args.clusters_configurations
if args.clusters_configurations
else None,
"out_conf_name": (
basenameout + "_confs" if args.clusters_configurations else None
),
"out_conf_fmt": (
args.clusters_configurations if args.clusters_configurations else None
),
"plot": bool(args.plot),
"evo_name": basenameout + "_evo.pdf" if args.plot else None,
"dendrogram_name": basenameout + "_dendrogram.pdf" if args.plot else None,
Expand Down Expand Up @@ -777,7 +779,9 @@ def save_clusters_config(
# whereins = np.where(
# np.isin(np.arange(natoms), reorderexcl[soluexcl]) is True
# )
whereins = np.where(np.atleast_1d(np.isin(np.arange(natoms), reorderexcl)))
whereins = np.where(

Check warning on line 782 in clusttraj/io.py

View check run for this annotation

Codecov / codecov/patch

clusttraj/io.py#L782

Added line #L782 was not covered by tests
np.atleast_1d(np.isin(np.arange(natoms), reorderexcl))
)
Psolu = np.insert(
Pview,
[x - whereins[0].tolist().index(x) for x in whereins[0]],
Expand Down Expand Up @@ -850,7 +854,9 @@ def save_clusters_config(

# build the total molecule with the reordered atoms
# whereins = np.where(np.isin(np.arange(len(P)), reorderexcl) is True)
whereins = np.where(np.atleast_1d(np.isin(np.arange(len(P)), reorderexcl)))
whereins = np.where(

Check warning on line 857 in clusttraj/io.py

View check run for this annotation

Codecov / codecov/patch

clusttraj/io.py#L857

Added line #L857 was not covered by tests
np.atleast_1d(np.isin(np.arange(len(P)), reorderexcl))
)
Pr = np.insert(
Pview,
[x - whereins[0].tolist().index(x) for x in whereins[0]],
Expand Down Expand Up @@ -883,4 +889,6 @@ def save_clusters_config(

# closes the file for the cnum cluster
outfile.close()
# type: ignore


# type: ignore
3 changes: 0 additions & 3 deletions clusttraj/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,16 @@
from scipy.cluster.hierarchy import cophenet
from typing import Tuple
import numpy as np
# from .io import ClustOptions


def compute_metrics(
# clust_opt: ClustOptions,
distmat: np.ndarray,
z_matrix: np.ndarray,
clusters: np.ndarray,
) -> Tuple[np.float64, np.float64, np.float64, np.float64]:
"""Compute metrics to assess the performance of the clustering procedure.
Args:
# clust_opt (ClustOptions): The clustering options.
distmat: The distance matrix.
z_matrix (np.ndarray): The Z-matrix from hierarchical clustering procedure.
clusters (np.ndarray): The cluster classifications for each sample.
Expand Down
55 changes: 27 additions & 28 deletions clusttraj/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from .io import ClustOptions


def plot_clust_evo(
clust_opt: ClustOptions,
clusters: np.ndarray
) -> None:
def plot_clust_evo(clust_opt: ClustOptions, clusters: np.ndarray) -> None:
"""Plot the evolution of cluster classification over the given samples.
Args:
Expand All @@ -25,30 +22,34 @@ def plot_clust_evo(
Returns:
None
"""

# Define a color for the lines
line_color = (0, 0, 0, 0.5)

# plot evolution with o cluster in trajectory
plt.figure(figsize=(10, 6))

# Set the y-axis to only show integers
plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))

# Increase tick size and font size
plt.tick_params(axis='both', which='major', direction='in', labelsize=12)
plt.tick_params(axis="both", which="major", direction="in", labelsize=12)

plt.plot(range(1, len(clusters) + 1), clusters, markersize=4, color=line_color)
plt.scatter(range(1, len(clusters) + 1), clusters, marker="o", c=clusters, cmap=plt.cm.nipy_spectral)
plt.scatter(
range(1, len(clusters) + 1),
clusters,
marker="o",
c=clusters,
cmap=plt.cm.nipy_spectral,
)
plt.xlabel("Sample Index", fontsize=14)
plt.ylabel("Cluster classification", fontsize=14)
plt.savefig(clust_opt.evo_name, bbox_inches="tight")


def plot_dendrogram(
clust_opt: ClustOptions,
clusters: np.ndarray,
Z: np.ndarray
clust_opt: ClustOptions, clusters: np.ndarray, Z: np.ndarray
) -> None:
"""Plot a dendrogram based on hierarchical clustering.
Expand All @@ -65,18 +66,22 @@ def plot_dendrogram(
plt.title("Hierarchical Clustering Dendrogram", fontsize=20)
# plt.xlabel("Sample Index", fontsize=14)
plt.ylabel(r"RMSD ($\AA$)", fontsize=18)
plt.tick_params(axis='y', labelsize=18)
plt.tick_params(axis="y", labelsize=18)

# Define a color for the dashed and non-cluster lines
line_color = (0, 0, 0, 0.5)

# Add a horizontal line at the minimum RMSD value and set the threshold
if clust_opt.silhouette_score:
if isinstance(clust_opt.optimal_cut, (np.ndarray, list)):
plt.axhline(clust_opt.optimal_cut[0], linestyle="--", linewidth=2, color=line_color)
plt.axhline(

Check warning on line 77 in clusttraj/plot.py

View check run for this annotation

Codecov / codecov/patch

clusttraj/plot.py#L77

Added line #L77 was not covered by tests
clust_opt.optimal_cut[0], linestyle="--", linewidth=2, color=line_color
)
threshold = clust_opt.optimal_cut[0]
elif isinstance(clust_opt.optimal_cut, (float, np.float32, np.float64)):
plt.axhline(clust_opt.optimal_cut, linestyle="--", linewidth=2, color=line_color)
plt.axhline(

Check warning on line 82 in clusttraj/plot.py

View check run for this annotation

Codecov / codecov/patch

clusttraj/plot.py#L82

Added line #L82 was not covered by tests
clust_opt.optimal_cut, linestyle="--", linewidth=2, color=line_color
)
threshold = clust_opt.optimal_cut
else:
raise ValueError("optimal_cut must be a float or np.ndarray")
Expand All @@ -86,9 +91,9 @@ def plot_dendrogram(

# Use the 'nipy_spectral' cmap to color the dendrogram
unique_clusters = np.unique(clusters)
cmap = cm.get_cmap('nipy_spectral', len(unique_clusters))
cmap = cm.get_cmap("nipy_spectral", len(unique_clusters))
colors = [to_hex(cmap(i)) for i in range(cmap.N)]

hierarchy.set_link_color_palette(colors)

# Plot the dendrogram
Expand All @@ -98,18 +103,14 @@ def plot_dendrogram(
# leaf_font_size=8.0, # Font size for the x axis labels
no_labels=True,
color_threshold=threshold,
above_threshold_color=line_color
above_threshold_color=line_color,
)

# Save the dendrogram to a file
plt.savefig(clust_opt.dendrogram_name, bbox_inches="tight")


def plot_mds(
clust_opt: ClustOptions,
clusters: np.ndarray,
distmat: np.ndarray
) -> None:
def plot_mds(clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray) -> None:
"""Plot the multidimensional scaling (MDS) of the distance matrix.
Args:
Expand Down Expand Up @@ -139,7 +140,7 @@ def plot_mds(
coords = mds.fit_transform(squareform(distmat))

# Set the figure size
plt.figure(figsize=(6,6))
plt.figure(figsize=(6, 6))

# Configure tick parameters
plt.tick_params(
Expand All @@ -165,9 +166,7 @@ def plot_mds(


def plot_tsne(
clust_opt: ClustOptions,
clusters: np.ndarray,
distmat: np.ndarray
clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray
) -> None:
"""Plot the t-distributed Stochastic Neighbor Embedding 2D plot of the clustering.
Expand All @@ -194,11 +193,11 @@ def plot_tsne(

# Define a list of unique colors for each cluster
unique_clusters = np.unique(clusters)
cmap = cm.get_cmap('nipy_spectral', len(unique_clusters))
cmap = cm.get_cmap("nipy_spectral", len(unique_clusters))

Check warning on line 196 in clusttraj/plot.py

View check run for this annotation

Codecov / codecov/patch

clusttraj/plot.py#L196

Added line #L196 was not covered by tests
colors = [cmap(i) for i in range(len(unique_clusters))]

# Set the figure size
plt.figure(figsize=(6,6))
plt.figure(figsize=(6, 6))

Check warning on line 200 in clusttraj/plot.py

View check run for this annotation

Codecov / codecov/patch

clusttraj/plot.py#L200

Added line #L200 was not covered by tests

# Configure tick parameters
plt.tick_params(
Expand Down

0 comments on commit a8806a6

Please sign in to comment.