| /* AUTORIGHTS |
| Copyright (C) 2007 Princeton University |
| |
| This file is part of Ferret Toolkit. |
| |
| Ferret Toolkit is free software; you can redistribute it and/or modify |
| it under the terms of the GNU General Public License as published by |
| the Free Software Foundation; either version 2, or (at your option) |
| any later version. |
| |
| This program is distributed in the hope that it will be useful, |
| but WITHOUT ANY WARRANTY; without even the implied warranty of |
| MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| GNU General Public License for more details. |
| |
| You should have received a copy of the GNU General Public License |
| along with this program; if not, write to the Free Software Foundation, |
| Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. |
| */ |
| #include <math.h> |
| #include "LSH.h" |
| #include <cass_topk.h> |
| #include "local.h" |
| |
/* Ordering predicate handed to HEAP_ENQUEUE / HEAP_DEQUEUE for perturbation
 * vectors: non-zero when the entry `lhs` points to has a key greater than or
 * equal to the key of the entry `rhs` points to. */
#define ptb_vec_ge(lhs, rhs) ((lhs)->key >= (rhs)->key)
| |
| void LSH_query_init (LSH_query_t *query, LSH_t *lsh, cass_dataset_t *ds, cass_size_t K, cass_size_t L, cass_size_t T) |
| { |
| memset(query, 0, sizeof(*query)); |
| query->lsh = lsh; |
| query->ds = ds; |
| |
| query->K = K; |
| query->L = L; |
| query->T = T; |
| |
| query->bitmap = bitmap_new(lsh->count); |
| |
| query->tmp = type_matrix_alloc(uint32_t, lsh->L, lsh->M); |
| assert(query->tmp != NULL); |
| query->tmp2 = (uint32_t *)malloc(sizeof(uint32_t) * lsh->L); |
| assert(query->tmp2 != NULL); |
| |
| query->ptb = type_matrix_alloc(ptb_vec_t, L, query->lsh->M * 2); |
| assert(query->ptb != NULL); |
| #ifdef QUERY_DIRECT |
| { |
| int i; |
| query->heap = type_calloc(typeof(query->heap), L); |
| assert(query->heap != NULL); |
| for (i = 0; i < L; i++) ARRAY_INIT(query->heap[i]); |
| } |
| #else |
| { |
| ptb_vec_t *scr = gen_score(query->lsh->M); |
| assert(scr != NULL); |
| query->ptb_set = type_calloc(ptb_vec_t, T); |
| assert(query->ptb_set != NULL); |
| gen_perturb_set(scr, query->ptb_set, query->lsh->M, T); |
| query->ptb_vec = type_matrix_alloc(ptb_vec_t, L, T); |
| assert(query->ptb_vec != NULL); |
| query->ptb_step = type_calloc(int, L); |
| assert(query->ptb_step != NULL); |
| free(scr); |
| } |
| #endif |
| query->topk = type_calloc(cass_list_entry_t, K); |
| query->_topk = type_matrix_alloc(cass_list_entry_t, L, K); |
| |
| assert(query->topk != NULL); |
| assert(query->_topk != NULL); |
| |
| query->C = (int*)malloc(L * sizeof(int)); |
| query->H = (int*)malloc(L * sizeof(int)); |
| query->S = (float*)malloc(L * sizeof(float)); |
| |
| assert(query->C != NULL); |
| assert(query->H != NULL); |
| assert(query->S != NULL); |
| |
| query->gamma = 1.0; |
| |
| } |
| |
| void LSH_query_cleanup (LSH_query_t *query) |
| { |
| bitmap_free(&query->bitmap); |
| matrix_free(query->tmp); |
| matrix_free(query->_topk); |
| free(query->topk); |
| free(query->tmp2); |
| #ifdef QUERY_DIRECT |
| { |
| int i; |
| for (i = 0; i < query->L; i++) ARRAY_CLEANUP(query->heap[i]); |
| free(query->heap); |
| } |
| #else |
| free(query->ptb_step); |
| free(query->ptb_set); |
| matrix_free(query->ptb_vec); |
| #endif |
| matrix_free(query->ptb); |
| free(query->C); |
| free(query->H); |
| free(query->S); |
| } |
| |
| void LSH_hash_score (LSH_t *lsh, int L, const float *pnt, uint32_t **hash, ptb_vec_t **ptb) |
| { |
| float s, t; |
| int i, j, k, p, l; |
| |
| l = 0; |
| for (i = 0; i < L; i++) |
| { |
| p = 0; |
| for (j = 0; j < lsh->M; j++) |
| { |
| s = lsh->betas[l]; |
| for (k = 0; k < lsh->D; k++) |
| { |
| s += pnt[k] * lsh->alphas[l][k]; |
| } |
| |
| t = floor(s / lsh->W[i]); |
| hash[i][j] = t; |
| t = s - t * lsh->W[i]; |
| ptb[i][p].set = 1 << j; |
| ptb[i][p].dir = 0; |
| ptb[i][p].key = t * t; |
| p++; |
| t = lsh->W[i] - t; |
| ptb[i][p].set = 1 << j; |
| ptb[i][p].dir = 1 << j; |
| ptb[i][p].key = t * t; |
| p++; |
| l++; |
| } |
| } |
| } |
| |
| void LSH_hash2_noperturb (LSH_t *lsh, uint32_t **hash, uint32_t *hash2, int L) |
| { |
| int i, j; |
| uint32_t h; |
| for (i = 0; i < L; i++) |
| { |
| h = 0; |
| for (j = 0; j < lsh->M; j++) |
| { |
| h += lsh->rnd[i][j] * hash[i][j]; |
| } |
| hash2[i] = h % lsh->H; |
| } |
| } |
| |
| void LSH_hash2_perturb (LSH_t *lsh, uint32_t **hash, uint32_t *hash2, ptb_vec_t *ptb, int l) |
| { |
| uint32_t j, mask; |
| uint32_t set, dir; |
| uint32_t h; |
| set = ptb->set; |
| dir = ptb->dir; |
| mask = 1; |
| h = 0; |
| for (j = 0; j < lsh->M; j++) |
| { |
| if (set & mask) |
| { |
| if (dir & mask) |
| { |
| h += lsh->rnd[l][j] * (hash[l][j] + 1); |
| } |
| else |
| { |
| h += lsh->rnd[l][j] * (hash[l][j] - 1); |
| } |
| } |
| else |
| { |
| h += lsh->rnd[l][j] * hash[l][j]; |
| } |
| mask = mask << 1; |
| } |
| *hash2 = h % lsh->H; |
| } |
| |
| |
| static void LSH_query_local (LSH_query_t *query) |
| { |
| double sx, sy, sxx, sxy; |
| double ld, lk; |
| double a, b; |
| int j; |
| int K = query->K; |
| sx = sy = sxx = sxy = 0.0; |
| for (j = 0; j < K-1; j++) |
| { |
| if (query->topk[K - j - 2].dist >= HUGE_VAL) break; |
| lk = log(j+1); |
| ld = log(query->topk[K - j - 2].dist); |
| sx += lk; |
| sy += ld; |
| sxx += lk * lk; |
| sxy += lk * ld; |
| } |
| least_squares(&a, &b, j, sxx, sxy, sx, sy); |
| a = exp(a); |
| //fprintf(stderr, ">%g\t%g\n", a, b); |
| query->dist = LSH_est(query->lsh->est, a, b, K -1) * query->gamma; |
| // fprintf(stderr, "%g\n", query->dist); |
| } |
| |
| extern void LSH_query_merge (LSH_query_t * query); |
| |
| static void LSH_query_bootstrap (LSH_query_t *query, const float *point) |
| { |
| cass_size_t D = query->lsh->D; |
| cass_size_t K = query->K; |
| cass_size_t L = query->L; |
| LSH_t *lsh = query->lsh; |
| ptb_vec_t **score = query->ptb; |
| uint32_t **tmp = query->tmp; |
| uint32_t *tmp2 = query->tmp2; |
| cass_list_entry_t **_topk = query->_topk; |
| cass_list_entry_t entry; |
| |
| int *C = query->C; |
| int *H = query->H; |
| |
| int i; |
| |
| memset(C, 0, L * sizeof(int)); |
| memset(H, 0, L * sizeof(int)); |
| memset(query->S, 0, L * sizeof(float)); |
| |
| bitmap_clear(query->bitmap); |
| |
| LSH_hash_score(query->lsh, L, point, tmp, score); |
| LSH_hash2_noperturb(query->lsh, tmp, tmp2, L); |
| |
| for (i = 0; i < L; i++) |
| { |
| memset(_topk[i], 0xff, sizeof (*_topk[i]) * K); |
| TOPK_INIT(_topk[i], dist, K, HUGE_VAL); |
| ARRAY_BEGIN_FOREACH(lsh->hash[i].bucket[tmp2[i]], uint32_t id) { |
| if (!bitmap_contain(query->bitmap, id)) |
| { |
| cass_vec_t *vec; |
| bitmap_insert(query->bitmap, id); |
| vec = DATASET_VEC(query->ds, id); |
| entry.id = id; |
| entry.dist = dist_L2_float(D, vec->u.float_data, point); |
| C[i]++; |
| query->CC++; |
| TOPK_INSERT_MIN_UNIQ_DO(_topk[i], dist, id, K, entry, H[i]++); |
| } |
| } |
| ARRAY_END_FOREACH; |
| |
| ptb_qsort(score[i], lsh->M * 2); |
| |
| #ifdef QUERY_DIRECT |
| ARRAY_TRUNC(query->heap[i]); |
| HEAP_ENQUEUE(query->heap[i], score[i][0], ptb_vec_ge); |
| #else |
| query->ptb_step[i] = 0; |
| map_perturb_vector(query->ptb_set, query->ptb_vec[i], score[i], lsh->M, query->T); |
| #endif |
| } |
| |
| if (lsh->est != NULL) |
| { |
| LSH_query_merge(query); |
| LSH_query_local(query); |
| } |
| |
| } |
| |
| static void LSH_query_probe (LSH_query_t *query, const float *point, int l, int g) |
| { |
| cass_size_t D = query->lsh->D; |
| cass_size_t K = query->K; |
| LSH_t *lsh = query->lsh; |
| uint32_t **tmp = query->tmp; |
| #ifdef QUERY_DIRECT |
| cass_size_t M = query->lsh->M; |
| ptb_vec_t *score = query->ptb[l]; |
| #endif |
| cass_list_entry_t *topk = g == 0? query->_topk[l] : query->topk; |
| cass_list_entry_t entry; |
| ptb_vec_t ptb; |
| int *C = query->C; |
| int *H = query->H; |
| uint32_t h; |
| #ifdef QUERY_DIRECT |
| typeof(query->heap) heap = &query->heap[l]; |
| if (HEAP_EMPTY(*heap)) return; |
| ptb = HEAP_HEAD(*heap); |
| HEAP_DEQUEUE(*heap, ptb_vec_ge); |
| #else |
| ptb = query->ptb_vec[l][query->ptb_step[l]++]; |
| #endif |
| LSH_hash2_perturb(query->lsh, tmp, &h, &ptb, l); |
| ARRAY_BEGIN_FOREACH(lsh->hash[l].bucket[h], uint32_t id) { |
| if (!bitmap_contain(query->bitmap, id)) |
| { |
| cass_vec_t *vec; |
| bitmap_insert(query->bitmap, id); |
| vec = DATASET_VEC(query->ds, id); |
| C[l]++; |
| query->CC++; |
| entry.id = id; |
| entry.dist = dist_L2_float(D, vec->u.float_data, point); |
| TOPK_INSERT_MIN_UNIQ_DO(topk, dist, id, K, entry, H[l]++); |
| } |
| } |
| ARRAY_END_FOREACH; |
| #ifdef QUERY_DIRECT |
| { |
| ptb_vec_t ptb2; |
| int m; |
| /* add expand */ |
| m = ptb.max; |
| ptb.max++; |
| if (ptb.max >= M * 2) return; |
| ptb2 = ptb; |
| ptb2.set &= ~score[m].set; |
| ptb2.dir &= ~score[m].dir; |
| ptb2.key -= score[m].key; |
| |
| if ((ptb.set & score[ptb.max].set) == 0) |
| { |
| ptb.set |= score[ptb.max].set; |
| ptb.dir |= score[ptb.max].dir; |
| ptb.key += score[ptb.max].key; |
| HEAP_ENQUEUE(*heap, ptb, ptb_vec_ge); |
| } |
| |
| if ((ptb2.set & score[ptb2.max].set) == 0) |
| { |
| ptb2.set |= score[ptb2.max].set; |
| ptb2.dir |= score[ptb2.max].dir; |
| ptb2.key += score[ptb2.max].key; |
| HEAP_ENQUEUE(*heap, ptb2, ptb_vec_ge); |
| } |
| } |
| #endif |
| } |
| |
| void LSH_query_merge (LSH_query_t *query) |
| { |
| cass_size_t K = query->K; |
| cass_size_t L = query->L; |
| cass_list_entry_t **_topk = query->_topk; |
| cass_list_entry_t *topk = query->topk; |
| int i, j; |
| |
| memset(topk, 0xff, sizeof (*topk) * K); |
| TOPK_INIT(topk, dist, K, HUGE_VAL); |
| |
| // query->CC = 0; |
| |
| for (i = 0; i < L; i++) |
| { |
| for (j = 0; j < K; j++) |
| { |
| TOPK_INSERT_MIN_UNIQ(topk, dist, id, K, _topk[i][j]); |
| } |
| // query->CC += query->C[i]; |
| } |
| } |
| |
| void LSH_query (LSH_query_t *query, const float *point) |
| { |
| LSH_est_t *est = query->lsh->est; |
| cass_list_entry_t *topk = query->topk; |
| float dist = query->dist; |
| |
| cass_size_t T = query->T; |
| cass_size_t L = query->L; |
| int i, j; |
| |
| query->CC = 0; |
| LSH_query_bootstrap(query, point); |
| |
| for (i = 0; i < L; i++) |
| { |
| for (j = 0; j < T; j++) |
| { |
| if (est != NULL) if (topk[0].dist <= dist) return; |
| LSH_query_probe(query, point, i, 0); |
| } |
| } |
| LSH_query_merge(query); |
| } |
| |
| void LSH_query_recall (LSH_query_t *query, const float *point, float R) |
| { |
| LSH_recall_t *recall = &query->lsh->recall; |
| cass_list_entry_t *topk = query->topk; |
| |
| cass_size_t T = query->T; |
| cass_size_t L = query->L; |
| int K = query->K; |
| int i, j; |
| query->CC = 0; |
| LSH_query_bootstrap(query, point); |
| LSH_query_merge(query); |
| for (i = 0; i < T; i++) |
| { |
| if (recall != NULL) |
| { |
| float r = 0; |
| for (j = 0; j < K; j++) |
| { |
| r += LSH_recall(recall, topk[j].dist, i); |
| } |
| r /= K; |
| if (r >= R) return; |
| |
| } |
| for (j = 0; j < L; j++) |
| { |
| LSH_query_probe(query, point, j, 1); |
| } |
| } |
| } |
| |
| |
| int LSH_query_select (LSH_query_t *query, int l) |
| { |
| cass_list_entry_t foo[query->L]; |
| int i; |
| for (i = 0; i < query->L; i++) |
| { |
| //query->S[i] = query->C[i] == 0 ? 0.0 : (float)query->H[i] / (float)query->C[i]; |
| foo[i].id = i; |
| foo[i].dist = query->S[i]; |
| //foo[i].index = query->_topk[i][0].key; |
| } |
| __cass_list_entry_qsort(foo, query->L); |
| /* |
| for (i = 0; i < query->L; i++) |
| { |
| printf("%d\t%g\n", foo[i].index, foo[i].key); |
| } |
| getchar(); |
| */ |
| for (i = 0; i < query->L - 1; i++) |
| { |
| if (rand() % 3 < 2) return foo[i].id; |
| } |
| return foo[i].id; |
| } |
| |
| void LSH_query_boost (LSH_query_t *query, const float *point) |
| { |
| cass_size_t T = query->T; |
| int i, l = 0; |
| query->CC = 0; |
| LSH_query_bootstrap(query, point); |
| for (i = 0; i < T; i++) |
| { |
| l = LSH_query_select(query, l); |
| /* |
| query->C[l] /= 2; |
| query->H[l] /= 2; |
| */ |
| LSH_query_probe(query, point, l, 0); |
| if (query->C[l] > 0) query->S[l] += (1.0 - (double)query->H[l] / (double)query->C[l]) * query->_topk[l][0].dist; |
| else query->S[l] = 0; |
| /* |
| if (query->H[l] == 0) query->S[l] *= 1.5; |
| else query->S[l] = query->_topk[l][0].key; |
| */ |
| } |
| LSH_query_merge(query); |
| } |
| |