Skip to content

Commit

Permalink
added a bunch of trial stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
epinzur committed Sep 20, 2024
1 parent 48b7f16 commit b9bee14
Show file tree
Hide file tree
Showing 15 changed files with 3,465 additions and 150 deletions.
20 changes: 20 additions & 0 deletions libs/knowledge-store/ragstack_knowledge_store/compare_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pickle
from langchain_core.documents import Document

from typing import Dict, List

def get_stuff(table_name: str, pattern: str = "debug_retrieval_{}.pkl"):
    """Load a pickled retrieval-debug dump for the given table name.

    Args:
        table_name: Name used when the dump was written; substituted into
            ``pattern`` to build the file name.
        pattern: File-name template with one ``{}`` placeholder for
            ``table_name``. Defaults to the original hard-coded layout,
            so existing callers are unaffected.

    Returns:
        Whatever object was pickled — here a mapping of query string to
        retrieved documents.
    """
    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
    # file. Only load dumps produced locally by this project, never
    # untrusted input.
    with open(pattern.format(table_name), "rb") as file:
        return pickle.load(file)


# Load the two debug dumps being compared: one produced by metadata-based
# retrieval, one by link-column-based retrieval.
metadata_based: Dict[str, List[Document]] = get_stuff("metadata_based")
link_based: Dict[str, List[Document]] = get_stuff("link_column_based")

# Print a per-query chunk-count comparison between the two strategies.
count = 1
for query, metadata_chunks in metadata_based.items():
    link_chunks = link_based[query]

    print(
        f"Query {count} has {len(metadata_chunks)} metadata chunks and "
        f"{len(link_chunks)} link chunks. "
        f"Diff: {len(metadata_chunks) - len(link_chunks)}"
    )
    count += 1
143 changes: 143 additions & 0 deletions libs/knowledge-store/ragstack_knowledge_store/concurrency copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from __future__ import annotations

import contextlib
import logging
import threading
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
NamedTuple,
Protocol,
Sequence,
)

if TYPE_CHECKING:
from types import TracebackType

from cassandra.cluster import ResponseFuture, Session
from cassandra.query import PreparedStatement, SimpleStatement

logger = logging.getLogger(__name__)


class _Callback(Protocol):
    """Structural type for result callbacks.

    A callback is invoked with one positional-only argument: the sequence
    of rows from a (page of a) query result. Its return value is ignored.
    """

    def __call__(self, rows: Sequence[Any], /) -> None:
        ...


class ConcurrentQueries(contextlib.AbstractContextManager["ConcurrentQueries"]):
    """Context manager that runs Cassandra queries concurrently.

    At most ``_MAX_CONCURRENT_QUERIES`` queries are in flight at once; a
    semaphore blocks ``execute`` until a slot frees up. ``__exit__`` waits
    for all pending queries to finish and re-raises the first error seen.
    """

    # Upper bound on simultaneously in-flight queries.
    _MAX_CONCURRENT_QUERIES = 5

    def __init__(self, session: Session) -> None:
        self._session = session
        # Guards _pending / _error; notified when work drains or an error
        # occurs so __exit__ can wake up.
        self._completion = threading.Condition()
        self._pending = 0
        self._error: BaseException | None = None
        self._semaphore = threading.Semaphore(self._MAX_CONCURRENT_QUERIES)

    def _finish_one(self, error: BaseException | None = None) -> None:
        """Mark one in-flight query finished, releasing its concurrency slot.

        Args:
            error: If not ``None``, recorded as the run's error (first one
                wins for ``__exit__``; later ones overwrite, matching the
                original behavior).
        """
        with self._completion:
            if error is not None:
                self._error = error
            self._pending -= 1
            self._semaphore.release()
            if self._error is not None or self._pending == 0:
                self._completion.notify()

    def _handle_result(
        self,
        result: Sequence[NamedTuple],
        future: ResponseFuture,
        callback: Callable[[Sequence[NamedTuple]], Any] | None,
    ) -> None:
        """Success callback: feed rows to ``callback``, then page or finish."""
        if callback is not None:
            try:
                callback(result)
            except Exception as e:
                # BUG FIX: a raising callback previously leaked both the
                # pending count and the semaphore slot, so __exit__ would
                # block forever. Record the error and unwind instead.
                self._finish_one(e)
                return

        if future.has_more_pages:
            # More pages: keep the slot and keep streaming results.
            future.start_fetching_next_page()
        else:
            self._finish_one()

    def _handle_error(self, error: BaseException, future: ResponseFuture) -> None:
        """Error callback: log, record the error, and release the slot."""
        logger.error(
            "Error executing query: %s",
            future.query,
            exc_info=error,
        )
        self._finish_one(error)

    def execute(
        self,
        query: PreparedStatement | SimpleStatement,
        parameters: tuple[Any, ...] | None = None,
        callback: _Callback | None = None,
        timeout: float | None = None,
    ) -> None:
        """Execute a query concurrently, blocking while at the concurrency limit.

        If a previous query already failed, this is a no-op (the error is
        raised from ``__exit__``).

        Args:
            query: The query to execute.
            parameters: Parameter tuple for the query. Defaults to `None`.
            callback: Callback to apply to the results. Defaults to `None`.
            timeout: Timeout to use (if not the session default).
        """
        with self._completion:
            if self._error is not None:
                return

        # Block until a concurrency slot is free.
        self._semaphore.acquire()

        with self._completion:
            # Re-check: an error may have arrived while we were blocked on
            # the semaphore; release the slot we just took before bailing.
            if self._error is not None:
                self._semaphore.release()
                return
            self._pending += 1

        try:
            execute_kwargs = {}
            if timeout is not None:
                execute_kwargs["timeout"] = timeout
            future: ResponseFuture = self._session.execute_async(
                query,
                parameters,
                **execute_kwargs,
            )
            future.add_callbacks(
                self._handle_result,
                self._handle_error,
                callback_kwargs={
                    "future": future,
                    "callback": callback,
                },
                errback_kwargs={
                    "future": future,
                },
            )
        except Exception as e:
            # Submission itself failed: undo the bookkeeping done above,
            # then propagate to the caller.
            self._finish_one(e)
            raise

    def __exit__(
        self,
        _exc_type: type[BaseException] | None,
        _exc_inst: BaseException | None,
        _exc_traceback: TracebackType | None,
    ) -> Literal[False]:
        """Wait for all in-flight queries; re-raise the first recorded error."""
        with self._completion:
            while self._error is None and self._pending > 0:
                self._completion.wait()

            if self._error is not None:
                raise self._error

        # Don't swallow the exception.
        return False
80 changes: 48 additions & 32 deletions libs/knowledge-store/ragstack_knowledge_store/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,27 @@
from types import TracebackType

from cassandra.cluster import ResponseFuture, Session
from cassandra.query import PreparedStatement
from cassandra.query import PreparedStatement, SimpleStatement

logger = logging.getLogger(__name__)


class _Callback(Protocol):
def __call__(self, rows: Sequence[Any], /) -> None: ...
def __call__(self, rows: Sequence[Any], /) -> None:
...


class ConcurrentQueries(contextlib.AbstractContextManager["ConcurrentQueries"]):
"""Context manager for concurrent queries."""
"""Context manager for concurrent queries with a max limit of 5 ongoing queries."""

_MAX_CONCURRENT_QUERIES = 5

def __init__(self, session: Session) -> None:
self._session = session
self._completion = threading.Condition()
self._pending = 0
self._error: BaseException | None = None
self._semaphore = threading.Semaphore(self._MAX_CONCURRENT_QUERIES)

def _handle_result(
self,
Expand All @@ -49,6 +53,7 @@ def _handle_result(
else:
with self._completion:
self._pending -= 1
self._semaphore.release() # Release the semaphore once a query completes
if self._pending == 0:
self._completion.notify()

Expand All @@ -60,53 +65,66 @@ def _handle_error(self, error: BaseException, future: ResponseFuture) -> None:
)
with self._completion:
self._error = error
self._pending -= 1 # Decrement pending count
self._semaphore.release() # Release the semaphore on error
self._completion.notify()

def execute(
self,
query: PreparedStatement,
query: PreparedStatement | SimpleStatement,
parameters: tuple[Any, ...] | None = None,
callback: _Callback | None = None,
timeout: float | None = None,
) -> None:
"""Execute a query concurrently.
Because this is done concurrently, it expects a callback if you need
to inspect the results.
"""Execute a query concurrently with a max of 5 concurrent queries.
Args:
query: The query to execute.
parameters: Parameter tuple for the query. Defaults to `None`.
callback: Callback to apply to the results. Defaults to `None`.
timeout: Timeout to use (if not the session default).
"""
# TODO: We could have some form of throttling, where we track the number
# of pending calls and queue things if it exceed some threshold.
with self._completion:
if self._error is not None:
return

# Acquire the semaphore before proceeding to ensure we do not exceed the max limit
self._semaphore.acquire()

with self._completion:
self._pending += 1
if self._error is not None:
# Release semaphore before returning
self._semaphore.release()
return
self._pending += 1

execute_kwargs = {}
if timeout is not None:
execute_kwargs["timeout"] = timeout
future: ResponseFuture = self._session.execute_async(
query,
parameters,
**execute_kwargs,
)
future.add_callbacks(
self._handle_result,
self._handle_error,
callback_kwargs={
"future": future,
"callback": callback,
},
errback_kwargs={
"future": future,
},
)
try:
execute_kwargs = {}
if timeout is not None:
execute_kwargs["timeout"] = timeout
future: ResponseFuture = self._session.execute_async(
query,
parameters,
**execute_kwargs,
)
future.add_callbacks(
self._handle_result,
self._handle_error,
callback_kwargs={
"future": future,
"callback": callback,
},
errback_kwargs={
"future": future,
},
)
except Exception as e:
with self._completion:
self._error = e
self._pending -= 1 # Decrement pending count
self._semaphore.release() # Release semaphore
self._completion.notify()
raise

def __exit__(
self,
Expand All @@ -122,6 +140,4 @@ def __exit__(
raise self._error

# Don't swallow the exception.
# We don't need to do anything with the exception (`_exc_*` parameters)
# since returning false here will automatically re-raise it.
return False
Loading

0 comments on commit b9bee14

Please sign in to comment.