Skip to content

Commit

Permalink
Updates Neo4jChatMessageHistory to use neo4j-graphrag queries
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Feb 26, 2025
1 parent 0240372 commit 4743719
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 34 deletions.
77 changes: 43 additions & 34 deletions libs/neo4j/langchain_neo4j/chat_message_histories/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict
from langchain_core.utils import get_from_dict_or_env
from neo4j_graphrag.message_history import (
ADD_MESSAGE_QUERY,
CREATE_SESSION_NODE_QUERY,
DELETE_MESSAGES_QUERY,
DELETE_SESSION_AND_MESSAGES_QUERY,
GET_MESSAGES_QUERY,
)

from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph

Expand Down Expand Up @@ -71,26 +78,27 @@ def __init__(
self._window = window
# Create session node
self._driver.execute_query(
f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
CREATE_SESSION_NODE_QUERY.format(node_label=self._node_label),
{"session_id": self._session_id},
).summary
)

@property
def messages(self) -> List[BaseMessage]:
"""Retrieve the messages from Neo4j"""
query = (
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
f"{self._window*2}]-() WITH p, length(p) AS length "
"ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
"RETURN {data:{content: node.content}, type:node.type} AS result"
)
records, _, _ = self._driver.execute_query(
query, {"session_id": self._session_id}
GET_MESSAGES_QUERY.format(
node_label=self._node_label, window=self._window * 2
),
{"session_id": self._session_id},
)

messages = messages_from_dict([el["result"] for el in records])
return messages
messages = [
{
"data": el["result"]["data"],
"type": el["result"]["role"],
}
for el in records
]
return messages_from_dict(messages)

@messages.setter
def messages(self, messages: List[BaseMessage]) -> None:
Expand All @@ -101,33 +109,34 @@ def messages(self, messages: List[BaseMessage]) -> None:

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Neo4j"""
query = (
f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id "
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
"CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
"SET new += {type:$type, content:$content} "
"WITH new, lm, last_message WHERE last_message IS NOT NULL "
"CREATE (last_message)-[:NEXT]->(new) "
"DELETE lm"
)
self._driver.execute_query(
query,
ADD_MESSAGE_QUERY.format(node_label=self._node_label),
{
"type": message.type,
"role": message.type,
"content": message.content,
"session_id": self._session_id,
},
).summary

def clear(self) -> None:
"""Clear session memory from Neo4j"""
query = (
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() "
"WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
"UNWIND nodes(p) as node DETACH DELETE node;"
)
self._driver.execute_query(query, {"session_id": self._session_id}).summary

def clear(self, delete_session_node: bool = False) -> None:
"""Clear session memory from Neo4j
Args:
delete_session_node (bool): Whether to delete the session node.
Defaults to False.
"""
if delete_session_node:
self._driver.execute_query(
query_=DELETE_SESSION_AND_MESSAGES_QUERY.format(
node_label=self._node_label
),
parameters_={"session_id": self._session_id},
)
else:
self._driver.execute_query(
query_=DELETE_MESSAGES_QUERY.format(node_label=self._node_label),
parameters_={"session_id": self._session_id},
)

def __del__(self) -> None:
if self._driver:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,45 @@ def test_invalid_credentials() -> None:
assert "Please ensure that the username and password are correct" in str(
exc_info.value
)


def test_neo4j_message_history_clear_messages() -> None:
message_history = Neo4jChatMessageHistory(
session_id="123", url=url, username=username, password=password
)
message_history.add_messages(
[
HumanMessage(content="You are a helpful assistant."),
AIMessage(content="Hello"),
]
)
assert len(message_history.messages) == 2
message_history.clear()
assert len(message_history.messages) == 0
# Test that the session node is not deleted
results = message_history._driver.execute_query(
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
)
assert len(results.records) == 1
assert results.records[0]["s"]["id"] == "123"
assert list(results.records[0]["s"].labels) == ["Session"]


def test_neo4j_message_history_clear_session_and_messages() -> None:
message_history = Neo4jChatMessageHistory(
session_id="123", url=url, username=username, password=password
)
message_history.add_messages(
[
HumanMessage(content="You are a helpful assistant."),
AIMessage(content="Hello"),
]
)
assert len(message_history.messages) == 2
message_history.clear(delete_session_node=True)
assert len(message_history.messages) == 0
# Test that the session node is deleted
results = message_history._driver.execute_query(
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
)
assert results.records == []

0 comments on commit 4743719

Please sign in to comment.