diff --git a/forum/api/comments.py b/forum/api/comments.py index 470bb276..a5c477c1 100644 --- a/forum/api/comments.py +++ b/forum/api/comments.py @@ -11,6 +11,7 @@ from forum.backends.mongodb.api import ( create_comment, delete_comment_by_id, + get_thread_id_by_comment_id, get_thread_by_id, get_user_by_id, mark_as_read, @@ -18,6 +19,8 @@ update_comment_and_get_updated_comment, update_stats_for_course, ) +from forum.backends.mongodb.comments import Comment +from forum.backends.mongodb.threads import CommentThread from forum.serializers.comment import CommentSerializer from forum.utils import ForumV2RequestError @@ -120,6 +123,7 @@ def create_child_comment( anonymous, anonymous_to_peers, 1, + get_thread_id_by_comment_id(parent_comment_id), parent_id=parent_comment_id, ) if not comment: diff --git a/forum/api/users.py b/forum/api/users.py index 64b0e8ac..b8e7bd7b 100644 --- a/forum/api/users.py +++ b/forum/api/users.py @@ -6,12 +6,9 @@ from typing import Any from forum.backends.mongodb import Users -from forum.backends.mongodb.api import ( - get_group_ids_from_params, - user_to_hash, -) +from forum.backends.mongodb.api import user_to_hash from forum.serializers.users import UserSerializer -from forum.utils import ForumV2RequestError +from forum.utils import ForumV2RequestError, get_group_ids_from_params log = logging.getLogger(__name__) diff --git a/forum/backends/mongodb/api.py b/forum/backends/mongodb/api.py index 482515d8..fd8a0156 100644 --- a/forum/backends/mongodb/api.py +++ b/forum/backends/mongodb/api.py @@ -16,7 +16,6 @@ Subscriptions, Users, ) -from forum.utils import make_aware from forum.constants import RETIRED_BODY, RETIRED_TITLE from forum.utils import get_group_ids_from_params, get_sort_criteria, make_aware @@ -365,7 +364,7 @@ def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]: pipeline: list[dict[str, Any]] = [ { "$match": { - "comment_thread_id": {"$in": [tid for tid in thread_ids]}, + "comment_thread_id": {"$in": thread_ids}, "abuse_flaggers": {"$ne": []}, } }, @@ -430,7 +429,7 @@ def get_filtered_thread_ids( set: A set of filtered thread IDs based on the context and group ID criteria. """ context_query = { - "_id": {"$in": [tid for tid in thread_ids]}, + "_id": {"$in": thread_ids}, "context": context, } context_threads = CommentThread().find(context_query) @@ -440,7 +439,7 @@ def get_filtered_thread_ids( return context_thread_ids group_query = { - "_id": {"$in": [tid for tid in thread_ids]}, + "_id": {"$in": thread_ids}, "$or": [ {"group_id": {"$in": group_ids}}, {"group_id": {"$exists": False}}, @@ -464,7 +463,7 @@ def get_endorsed(thread_ids: list[str]) -> dict[str, bool]: """ endorsed_comments = Comment().find( { - "comment_thread_id": {"$in": [tid for tid in thread_ids]}, + "comment_thread_id": {"$in": thread_ids}, "endorsed": True, } ) @@ -1034,7 +1033,7 @@ def user_to_hash( comment_thread_ids = filter_standalone_threads(list(comments)) group_query = { - "_id": {"$in": [tid for tid in comment_thread_ids]}, + "_id": {"$in": comment_thread_ids}, "$and": [ {"group_id": {"$in": specified_groups_or_global}}, {"group_id": {"$exists": False}}, @@ -1310,7 +1309,7 @@ def create_comment( anonymous: bool, anonymous_to_peers: bool, depth: int, - thread_id: Optional[str] = None, + thread_id: str, parent_id: Optional[str] = None, ) -> Any: """ @@ -1406,3 +1405,13 @@ def update_comment_and_get_updated_comment( def delete_comment_by_id(comment_id: str) -> None: """Delete a comment by it's Id.""" Comment().delete(comment_id) + + +def get_thread_id_by_comment_id(parent_comment_id: str) -> str: + """ + The thread Id from the parent comment. + """ + parent_comment = Comment().get(parent_comment_id) + if parent_comment: + return parent_comment["comment_thread_id"] + raise ValueError("Comment doesn't have the thread.") diff --git a/forum/backends/mongodb/comments.py b/forum/backends/mongodb/comments.py index e2a2f650..4212019b 100644 --- a/forum/backends/mongodb/comments.py +++ b/forum/backends/mongodb/comments.py @@ -100,7 +100,7 @@ def insert( """ date = datetime.now() comment_data = { - "_id":str(ObjectId()), + "_id": str(ObjectId()), "votes": self.get_votes_dict(up=[], down=[]), "visible": visible, "abuse_flaggers": abuse_flaggers or [], diff --git a/forum/utils.py b/forum/utils.py index 97dd332b..0b613b74 100644 --- a/forum/utils.py +++ b/forum/utils.py @@ -171,7 +171,6 @@ def get_group_ids_from_params(params: dict[str, Any]) -> list[int]: Raises: ValueError: If both `group_id` and `group_ids` are specified in the parameters. """ - if "group_id" in params and "group_ids" in params: raise ValueError("Cannot specify both group_id and group_ids") group_ids: str | list[str] = [] @@ -221,10 +220,3 @@ def get_sort_criteria(sort_key: str) -> Sequence[tuple[str, int]]: class ForumV2RequestError(Exception): pass - - -def make_aware(dt: datetime) -> datetime: - """Make datetime timezone-aware.""" - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - return dt diff --git a/forum/views/comments.py b/forum/views/comments.py index 0985864f..122a77b9 100644 --- a/forum/views/comments.py +++ b/forum/views/comments.py @@ -14,7 +14,7 @@ get_parent_comment, update_comment, ) -from forum.utils import ForumV2RequestError +from forum.utils import ForumV2RequestError, str_to_bool class CommentsAPIView(APIView): diff --git a/forum/views/users.py b/forum/views/users.py index 6738c122..8da15a7c 100644 --- a/forum/views/users.py +++ b/forum/views/users.py @@ -27,7 +27,6 @@ ) from forum.serializers.thread import ThreadSerializer from forum.serializers.users import UserSerializer -from forum.utils import get_group_ids_from_params from forum.utils import ForumV2RequestError log = logging.getLogger(__name__)