Skip to content

Commit b9b1f73

Browse files
committed
Some issues fixed
#: 1 Severity: HIGH File: distance-avx512.c:877 Fix: Removed (n+7)/8 — n is already byte count from all callers, matching CPU/NEON/SSE2/AVX2 backends ──────────────────────────────────────── #: 2 Severity: MEDIUM File: sqlite-vector.c:1949 Fix: Added vector_allocated flag and sqlite3_free calls on all exit paths in vCursorFilterCommon when vector was allocated by vector_from_json ──────────────────────────────────────── #: 3 Severity: LOW File: sqlite-vector.c:1194 Fix: Swapped the ternary branches so is_without_rowid==true gets the "must have INTEGER PRIMARY KEY" error and is_without_rowid==false gets the "Out of memory" error ──────────────────────────────────────── #: 4 Severity: LOW File: sqlite-vector.c:1058 Fix: Added KEY_MATCH macro that checks key_len == sizeof(key)-1 before strncasecmp, preventing prefix matches like "ty" matching "type"
1 parent bfda256 commit b9b1f73

File tree

3 files changed

+39
-25
lines changed

3 files changed

+39
-25
lines changed

src/distance-avx512.c

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -873,17 +873,16 @@ static inline __m512i popcount_avx512(__m512i v) {
873873
}
874874

875875
// Hamming distance for 1-bit packed binary vectors
876-
// n = number of dimensions (bits), not bytes
876+
// n = number of bytes (callers pass (dimension + 7) / 8)
877877
static float bit1_distance_hamming_avx512(const void *v1, const void *v2, int n) {
878878
const uint8_t *a = (const uint8_t *)v1;
879879
const uint8_t *b = (const uint8_t *)v2;
880-
int num_bytes = (n + 7) / 8;
881880

882881
__m512i acc = _mm512_setzero_si512();
883882
int i = 0;
884883

885884
// Process 64 bytes at a time
886-
for (; i + 64 <= num_bytes; i += 64) {
885+
for (; i + 64 <= n; i += 64) {
887886
__m512i va = _mm512_loadu_si512((const __m512i *)(a + i));
888887
__m512i vb = _mm512_loadu_si512((const __m512i *)(b + i));
889888
__m512i xored = _mm512_xor_si512(va, vb);
@@ -904,7 +903,7 @@ static float bit1_distance_hamming_avx512(const void *v1, const void *v2, int n)
904903
uint64_t distance = _mm512_reduce_add_epi64(acc);
905904

906905
// Handle remaining bytes with scalar code
907-
for (; i < num_bytes; i++) {
906+
for (; i < n; i++) {
908907
#if defined(__GNUC__) || defined(__clang__)
909908
distance += __builtin_popcount(a[i] ^ b[i]);
910909
#else

src/sqlite-vector.c

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ SQLITE_EXTENSION_INIT1
100100
((int64_t)((uint8_t)(_ptr)[7]) << 56))
101101

102102
#define SWAP(_t, a, b) do { _t tmp = (a); (a) = (b); (b) = tmp; } while (0)
103+
#define KEY_MATCH(_k) (key_len == (int)sizeof(_k) - 1 && strncasecmp(key, _k, key_len) == 0)
103104

104105
#define VECTOR_COLUMN_IDX 0
105106
#define VECTOR_COLUMN_VECTOR 1
@@ -1054,47 +1055,47 @@ bool vector_keyvalue_callback (sqlite3_context *context, void *xdata, const char
10541055
char buffer[256] = {0};
10551056
size_t len = ((size_t)value_len > sizeof(buffer)-1) ? sizeof(buffer)-1 : (size_t)value_len;
10561057
memcpy(buffer, value, len);
1057-
1058-
if (strncasecmp(key, OPTION_KEY_TYPE, key_len) == 0) {
1058+
1059+
if (KEY_MATCH(OPTION_KEY_TYPE)) {
10591060
vector_type type = vector_name_to_type(buffer);
10601061
if (type == 0) return context_result_error(context, SQLITE_ERROR, "Invalid vector type: '%s' is not a recognized type", buffer);
10611062
options->v_type = type;
10621063
return true;
10631064
}
10641065

1065-
if (strncasecmp(key, OPTION_KEY_DIMENSION, key_len) == 0) {
1066+
if (KEY_MATCH(OPTION_KEY_DIMENSION)) {
10661067
int dimension = (int)strtol(buffer, NULL, 0);
10671068
if (dimension <= 0) return context_result_error(context, SQLITE_ERROR, "Invalid vector dimension: expected a positive integer, got '%s'", buffer);
10681069
options->v_dim = dimension;
10691070
return true;
10701071
}
10711072

1072-
if (strncasecmp(key, OPTION_KEY_NORMALIZED, key_len) == 0) {
1073+
if (KEY_MATCH(OPTION_KEY_NORMALIZED)) {
10731074
int normalized = (int)strtol(buffer, NULL, 0);
10741075
options->v_normalized = (normalized != 0);
10751076
return true;
10761077
}
10771078

1078-
if (strncasecmp(key, OPTION_KEY_MAXMEMORY, key_len) == 0) {
1079+
if (KEY_MATCH(OPTION_KEY_MAXMEMORY)) {
10791080
uint64_t max_memory = human_to_number(buffer);
10801081
if (max_memory > 0) options->max_memory = max_memory;
10811082
return true;
10821083
}
10831084

1084-
if (strncasecmp(key, OPTION_KEY_QUANTTYPE, key_len) == 0) {
1085+
if (KEY_MATCH(OPTION_KEY_QUANTTYPE)) {
10851086
vector_qtype type = quant_name_to_type(buffer);
10861087
if ((int)type == -1) return context_result_error(context, SQLITE_ERROR, "Invalid quantization type: '%s' is not a recognized or supported quantization type", buffer);
10871088
options->q_type = type;
10881089
return true;
10891090
}
10901091

1091-
if (strncasecmp(key, OPTION_KEY_DISTANCE, key_len) == 0) {
1092+
if (KEY_MATCH(OPTION_KEY_DISTANCE)) {
10921093
vector_distance type = distance_name_to_type(buffer);
10931094
if (type == 0) return context_result_error(context, SQLITE_ERROR, "Invalid distance name: '%s' is not a recognized or supported distance", buffer);
10941095
options->v_distance = type;
10951096
return true;
10961097
}
1097-
1098+
10981099
// means ignore unknown keys
10991100
return true;
11001101
}
@@ -1191,7 +1192,7 @@ void vector_context_add (sqlite3_context *context, vector_context *ctx, const ch
11911192

11921193
// sanity check primary key
11931194
if (!prikey) {
1194-
(is_without_rowid) ? context_result_error(context, SQLITE_NOMEM, "Out of memory: unable to duplicate rowid column name") : context_result_error(context, SQLITE_ERROR, "WITHOUT ROWID table '%s' must have exactly one PRIMARY KEY column of type INTEGER", table_name);
1195+
(is_without_rowid) ? context_result_error(context, SQLITE_ERROR, "WITHOUT ROWID table '%s' must have exactly one PRIMARY KEY column of type INTEGER", table_name) : context_result_error(context, SQLITE_NOMEM, "Out of memory: unable to duplicate rowid column name");
11951196
sqlite3_free(t_name);
11961197
sqlite3_free(c_name);
11971198
return;
@@ -1945,11 +1946,13 @@ static int vCursorFilterCommon (sqlite3_vtab_cursor *cur, int idxNum, const char
19451946
}
19461947

19471948
const void *vector = NULL;
1949+
bool vector_allocated = false;
19481950
int vsize = 0;
19491951
if (sqlite3_value_type(argv[2]) == SQLITE_TEXT) {
19501952
vsize = sqlite3_value_bytes(argv[2]);
19511953
vector = (const void *)vector_from_json(NULL, &vtab->base, t_ctx->options.v_type, (const char *)sqlite3_value_text(argv[2]), &vsize, t_ctx->options.v_dim);
19521954
if (!vector) return SQLITE_ERROR; // error already set inside vector_from_json
1955+
vector_allocated = true;
19531956
} else {
19541957
vector = (const void *)sqlite3_value_blob(argv[2]);
19551958
vsize = sqlite3_value_bytes(argv[2]);
@@ -1962,48 +1965,60 @@ static int vCursorFilterCommon (sqlite3_vtab_cursor *cur, int idxNum, const char
19621965
char *name = generate_quant_table_name(table_name, column_name, buffer);
19631966
if (!name || !sqlite_table_exists(vtab->db, name)) {
19641967
sqlite_vtab_set_error(&vtab->base, "Quantization table not found for table '%s' and column '%s'. Ensure that vector_quantize() has been called before using vector_quantize_scan()", table_name, column_name);
1968+
if (vector_allocated) sqlite3_free((void *)vector);
19651969
return SQLITE_ERROR;
19661970
}
19671971
}
1968-
1972+
19691973
c->table = t_ctx;
19701974
if (is_streaming) {
19711975
int rc = stream_callback(vtab->db, c, vector, vsize);
1976+
if (vector_allocated) sqlite3_free((void *)vector);
19721977
if (rc != SQLITE_OK) return rc;
19731978
return vFullScanCursorNext((sqlite3_vtab_cursor *)c); // Position on first row
19741979
}
19751980

19761981
// non-streaming flow
19771982
int k = sqlite3_value_int(argv[3]);
1978-
if (k == 0) return SQLITE_DONE;
1979-
1983+
if (k == 0) {
1984+
if (vector_allocated) sqlite3_free((void *)vector);
1985+
return SQLITE_DONE;
1986+
}
1987+
19801988
if (c->row_count != k) {
19811989
if (c->rowids) sqlite3_free(c->rowids);
19821990
c->rowids = (int64_t *)sqlite3_malloc(k * sizeof(int64_t));
1983-
if (c->rowids == NULL) return SQLITE_NOMEM;
1984-
1991+
if (c->rowids == NULL) {
1992+
if (vector_allocated) sqlite3_free((void *)vector);
1993+
return SQLITE_NOMEM;
1994+
}
1995+
19851996
if (c->distance) sqlite3_free(c->distance);
19861997
c->distance = (double *)sqlite3_malloc(k * sizeof(double));
1987-
if (c->distance == NULL) return SQLITE_NOMEM;
1998+
if (c->distance == NULL) {
1999+
if (vector_allocated) sqlite3_free((void *)vector);
2000+
return SQLITE_NOMEM;
2001+
}
19882002
}
1989-
2003+
19902004
memset(c->rowids, 0, k * sizeof(int64_t));
19912005
for (int i=0; i<k; ++i) c->distance[i] = INFINITY;
1992-
2006+
19932007
c->size = 0;
19942008
c->row_index = 0;
19952009
c->row_count = k;
1996-
2010+
19972011
int rc = run_callback(vtab->db, c, vector, vsize);
2012+
if (vector_allocated) sqlite3_free((void *)vector);
19982013
int count = sort_callback(c);
19992014
c->row_count -= count;
2000-
2015+
20012016
#if 0
20022017
for (int i=0; i<c->row_count; ++i) {
20032018
printf("%lld\t%f\n", (long long)c->rowids[i], c->distance[i]);
20042019
}
20052020
#endif
2006-
2021+
20072022
return rc;
20082023
}
20092024

src/sqlite-vector.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
extern "C" {
2525
#endif
2626

27-
#define SQLITE_VECTOR_VERSION "0.9.91"
27+
#define SQLITE_VECTOR_VERSION "0.9.92"
2828

2929
SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
3030

0 commit comments

Comments
 (0)