Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CNDB-12980: Clone the vectors issued by the postingsMap in CompactionGraph #1586

Merged
merged 2 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,9 @@ public InsertionResult maybeAddVector(ByteBuffer term, int segmentRowId) throws
var trainingVectors = new ArrayList<VectorFloat<?>>(postingsMap.size());
var vectorsByOrdinal = new Int2ObjectHashMap<VectorFloat<?>>();
postingsMap.forEach((v, p) -> {
trainingVectors.add(v);
vectorsByOrdinal.put(p.getOrdinal(), v);
var vectorClone = v.copy();
trainingVectors.add(vectorClone);
vectorsByOrdinal.put(p.getOrdinal(), vectorClone);
});

// lock the addGraphNode threads out so they don't try to use old pq codepoints against the new codebook
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@
package org.apache.cassandra.index.sai.cql;

import java.util.ArrayList;
import java.util.stream.Collectors;

import org.junit.Test;

import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.index.sai.disk.v3.V3OnDiskFormat;
import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter;
import org.apache.cassandra.index.sai.disk.vector.CompactionGraph;

Expand Down Expand Up @@ -59,22 +63,38 @@ public void testPQRefine()
createTable();
disableCompaction();

var vectors = new ArrayList<float[]>();
// 3 sstables
for (int j = 0; j < 3; j++)
{
for (int i = 0; i <= MIN_PQ_ROWS; i++)
{
var pk = j * MIN_PQ_ROWS + i;
execute("INSERT INTO %s (pk, v) VALUES (?, ?)", pk, vector(pk, pk + 1));
var v = create2DVector();
vectors.add(v);
execute("INSERT INTO %s (pk, v) VALUES (?, ?)", pk, vector(v));
}
flush();
}

CompactionGraph.PQ_TRAINING_SIZE = 2 * MIN_PQ_ROWS;
compact();

// Confirm we can query the data
assertRowCount(execute("SELECT * FROM %s ORDER BY v ANN OF [1,2] LIMIT 1"), 1);
// Confirm we can query the data with reasonable recall
double recall = 0;
int ITERS = 10;
for (int i = 0; i < ITERS; i++)
{
var q = create2DVector();
var result = execute("SELECT pk, v FROM %s ORDER BY v ANN OF ? LIMIT 20", vector(q));
var ann = result.stream().map(row -> {
var vList = row.getVector("v", FloatType.instance, 2);
return new float[]{ vList.get(0), vList.get(1) };
}).collect(Collectors.toList());
recall += computeRecall(vectors, q, ann, VectorSimilarityFunction.COSINE);
}
recall /= ITERS;
assert recall >= 0.9 : recall;
}

@Test
Expand Down Expand Up @@ -179,4 +199,9 @@ private void validateQueries()
}
}
}

private static float[] create2DVector() {
var R = getRandom();
return new float[] { R.nextFloatBetween(-100, 100), R.nextFloatBetween(-100, 100) };
}
}