/* blob: 9bb8aeaa4608f381caf8af1e2fcef8508ae0b7ac [file] [log] [blame] */
/* AUTORIGHTS
Copyright (C) 2007 Princeton University
This file is part of Ferret Toolkit.
Ferret Toolkit is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2, or (at your option)
any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software Foundation,
Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
#include <math.h>
#include "LSH.h"
#include <cass_topk.h>
#include "local.h"
/* Ordering predicate handed to the HEAP_* macros (used only in QUERY_DIRECT
   builds): compares two perturbation vectors by score, a->key >= b->key.
   NOTE(review): whether this makes the heap yield the smallest key first
   depends on how HEAP_ENQUEUE/HEAP_DEQUEUE use it — presumably yes; confirm. */
#define ptb_vec_ge(a,b) ((a)->key >= (b)->key)
/*
 * Prepare a query object for searching an LSH index.
 *
 * query: object to initialize; it is zeroed first, so every field not set
 *        below starts out as 0/NULL.
 * lsh:   index to search.
 * ds:    dataset whose vectors the bucket ids refer to.
 * K:     length of the result list.
 * L:     number of hash tables to use.
 * T:     perturbation budget (in non-QUERY_DIRECT builds each table gets T
 *        precomputed perturbation vectors).
 *
 * Allocation failures are detected with assert() only.
 * NOTE(review): tmp/tmp2 are sized by lsh->L while the other per-table
 * buffers use L, but LSH_query_bootstrap() fills L rows of tmp, so this
 * presumably requires L <= lsh->L — confirm.
 */
void LSH_query_init (LSH_query_t *query, LSH_t *lsh, cass_dataset_t *ds, cass_size_t K, cass_size_t L, cass_size_t T)
{
    memset(query, 0, sizeof(*query));
    query->lsh = lsh;
    query->ds = ds;
    query->K = K;
    query->L = L;
    query->T = T;
    query->gamma = 1.0;

    /* Set of dataset ids already visited during a query (sized by lsh->count). */
    query->bitmap = bitmap_new(lsh->count);

    /* Per-table first-level hashes and second-level bucket indices. */
    query->tmp = type_matrix_alloc(uint32_t, lsh->L, lsh->M);
    assert(query->tmp != NULL);
    query->tmp2 = malloc(sizeof(uint32_t) * lsh->L);
    assert(query->tmp2 != NULL);

    /* Per-table perturbation scores: a -1 and a +1 candidate per hash function. */
    query->ptb = type_matrix_alloc(ptb_vec_t, L, lsh->M * 2);
    assert(query->ptb != NULL);

#ifdef QUERY_DIRECT
    {
        int t;
        query->heap = type_calloc(typeof(query->heap), L);
        assert(query->heap != NULL);
        for (t = 0; t < L; t++) {
            ARRAY_INIT(query->heap[t]);
        }
    }
#else
    {
        /* gen_score()'s output is only needed to build ptb_set. */
        ptb_vec_t *base_score = gen_score(lsh->M);
        assert(base_score != NULL);
        query->ptb_set = type_calloc(ptb_vec_t, T);
        assert(query->ptb_set != NULL);
        gen_perturb_set(base_score, query->ptb_set, lsh->M, T);
        free(base_score);

        query->ptb_vec = type_matrix_alloc(ptb_vec_t, L, T);
        assert(query->ptb_vec != NULL);
        query->ptb_step = type_calloc(int, L);
        assert(query->ptb_step != NULL);
    }
#endif

    /* Merged result list and one candidate list per table. */
    query->topk = type_calloc(cass_list_entry_t, K);
    assert(query->topk != NULL);
    query->_topk = type_matrix_alloc(cass_list_entry_t, L, K);
    assert(query->_topk != NULL);

    /* Per-table statistics used by LSH_query_boost()/LSH_query_select(). */
    query->C = malloc(L * sizeof(int));
    assert(query->C != NULL);
    query->H = malloc(L * sizeof(int));
    assert(query->H != NULL);
    query->S = malloc(L * sizeof(float));
    assert(query->S != NULL);
}
/*
 * Release every buffer allocated by LSH_query_init().
 * The freed pointers are not reset, so the query must be re-initialized
 * with LSH_query_init() before it is used again.
 */
void LSH_query_cleanup (LSH_query_t *query)
{
    /* Per-table statistics. */
    free(query->C);
    free(query->H);
    free(query->S);

    /* Result lists. */
    free(query->topk);
    matrix_free(query->_topk);

    /* Hashing and perturbation scratch space. */
    free(query->tmp2);
    matrix_free(query->tmp);
    matrix_free(query->ptb);

#ifdef QUERY_DIRECT
    {
        int t;
        for (t = 0; t < query->L; t++) {
            ARRAY_CLEANUP(query->heap[t]);
        }
        free(query->heap);
    }
#else
    matrix_free(query->ptb_vec);
    free(query->ptb_set);
    free(query->ptb_step);
#endif

    bitmap_free(&query->bitmap);
}
/*
 * Compute the first-level LSH hashes of pnt for tables 0..L-1, together with
 * the per-coordinate perturbation candidates used for multi-probing.
 *
 * For table i and hash function j (flat parameter index l = i*M + j):
 *     s           = betas[l] + <alphas[l], pnt>
 *     hash[i][j]  = floor(s / W[i])
 * ptb[i] receives 2*M entries, two per hash function j:
 *     - a "-1" step (set bit j, dir bit clear) keyed by the squared distance
 *       from s to the lower boundary of its slot;
 *     - a "+1" step (set bit j, dir bit set) keyed by the squared distance
 *       from s to the upper boundary.
 * (LSH_hash2_perturb() interprets dir-set as +1 and dir-clear as -1.)
 *
 * hash must have L rows of M entries and ptb L rows of 2*M entries.
 * set/dir are bit masks over the hash functions, so M must not exceed 32.
 */
void LSH_hash_score (LSH_t *lsh, int L, const float *pnt, uint32_t **hash, ptb_vec_t **ptb)
{
    float s, t;
    int i, j, k, p, l;
    l = 0;
    for (i = 0; i < L; i++)
    {
        p = 0;
        for (j = 0; j < lsh->M; j++)
        {
            /* Unsigned shift: 1 << 31 on a signed int is undefined. */
            uint32_t bit = 1u << j;

            s = lsh->betas[l];
            for (k = 0; k < lsh->D; k++)
            {
                s += pnt[k] * lsh->alphas[l][k];
            }
            t = floor(s / lsh->W[i]);
            /* The slot index is negative whenever s < 0. Converting a
               negative float directly to uint32_t is undefined behavior, so
               go through a signed integer; the signed->unsigned step is
               well defined and wraps modulo 2^32, which is what the
               +1/-1 arithmetic in LSH_hash2_perturb() relies on. */
            hash[i][j] = (uint32_t)(int64_t)t;

            /* Distance from s to the lower slot boundary: "-1" candidate. */
            t = s - t * lsh->W[i];
            ptb[i][p].set = bit;
            ptb[i][p].dir = 0;
            ptb[i][p].key = t * t;
            p++;

            /* Distance from s to the upper slot boundary: "+1" candidate. */
            t = lsh->W[i] - t;
            ptb[i][p].set = bit;
            ptb[i][p].dir = bit;
            ptb[i][p].key = t * t;
            p++;

            l++;
        }
    }
}
/*
 * Second-level (bucket) hash of the unperturbed first-level hashes:
 * hash2[t] = (sum_j rnd[t][j] * hash[t][j]) mod H for each table t < L,
 * with the sum accumulated in uint32_t (wrapping arithmetic).
 */
void LSH_hash2_noperturb (LSH_t *lsh, uint32_t **hash, uint32_t *hash2, int L)
{
    int table;
    for (table = 0; table < L; table++)
    {
        const uint32_t *key = hash[table];
        uint32_t acc = 0;
        int d;
        for (d = 0; d < lsh->M; d++)
        {
            acc += lsh->rnd[table][d] * key[d];
        }
        hash2[table] = acc % lsh->H;
    }
}
/*
 * Second-level (bucket) hash of table l after applying perturbation ptb.
 *
 * For each hash function d whose bit is in ptb->set, the first-level hash
 * hash[l][d] is shifted by +1 (bit also in ptb->dir) or -1 (bit not in
 * ptb->dir) before mixing; other coordinates are used unchanged. The result
 * (sum_d rnd[l][d] * value_d) mod H is stored in *hash2.
 */
void LSH_hash2_perturb (LSH_t *lsh, uint32_t **hash, uint32_t *hash2, ptb_vec_t *ptb, int l)
{
    const uint32_t set = ptb->set;
    const uint32_t dir = ptb->dir;
    const uint32_t *key = hash[l];
    uint32_t acc = 0;
    uint32_t bit = 1;
    uint32_t d;

    for (d = 0; d < lsh->M; d++, bit <<= 1)
    {
        uint32_t slot = key[d];
        if (set & bit)
        {
            /* Unsigned arithmetic: slot - 1 wraps when slot == 0. */
            slot = (dir & bit) ? slot + 1 : slot - 1;
        }
        acc += lsh->rnd[l][d] * slot;
    }
    *hash2 = acc % lsh->H;
}
/*
 * Refresh query->dist, the threshold LSH_query() uses for early termination,
 * from the current merged list query->topk.
 *
 * Pairs (log(rank), log(dist)) are collected from topk[K-2], topk[K-3], ...,
 * topk[0] with rank 1, 2, ..., stopping at the first entry whose distance is
 * >= HUGE_VAL (an unfilled slot). least_squares() fits them (presumably
 * log(dist) = a + b*log(rank) — confirm), and (exp(a), b) is passed to
 * LSH_est() together with K-1; the estimate is scaled by query->gamma.
 */
static void LSH_query_local (LSH_query_t *query)
{
    const int K = query->K;
    double sum_x = 0.0, sum_y = 0.0, sum_xx = 0.0, sum_xy = 0.0;
    double a, b;
    int n = 0;

    while (n < K - 1)
    {
        const double d = query->topk[K - n - 2].dist;
        double x, y;
        if (d >= HUGE_VAL)
        {
            break;
        }
        x = log(n + 1);
        y = log(d);
        sum_x += x;
        sum_y += y;
        sum_xx += x * x;
        sum_xy += x * y;
        n++;
    }

    least_squares(&a, &b, n, sum_xx, sum_xy, sum_x, sum_y);
    query->dist = LSH_est(query->lsh->est, exp(a), b, K - 1) * query->gamma;
}
/* Forward declaration: LSH_query_bootstrap() below calls LSH_query_merge(),
   which is defined further down in this file. */
extern void LSH_query_merge (LSH_query_t * query);
/*
 * First phase of every query.
 *
 * Resets the per-query state (C, H, S counters and the visited-id bitmap),
 * hashes point into every table, scans each table's home bucket (the
 * unperturbed hash) into that table's list query->_topk[i], and prepares the
 * perturbation sequence consumed by later LSH_query_probe() calls.
 * When an estimator is configured (lsh->est != NULL) it also merges the
 * per-table lists into query->topk and refreshes query->dist through
 * LSH_query_local().
 */
static void LSH_query_bootstrap (LSH_query_t *query, const float *point)
{
    cass_size_t D = query->lsh->D;
    cass_size_t K = query->K;
    cass_size_t L = query->L;
    LSH_t *lsh = query->lsh;
    ptb_vec_t **score = query->ptb;
    uint32_t **tmp = query->tmp;
    uint32_t *tmp2 = query->tmp2;
    cass_list_entry_t **_topk = query->_topk;
    cass_list_entry_t entry;
    int *C = query->C;
    int *H = query->H;
    int i;

    memset(C, 0, L * sizeof(int));
    memset(H, 0, L * sizeof(int));
    memset(query->S, 0, L * sizeof(float));
    bitmap_clear(query->bitmap);

    /* First-level hashes plus perturbation scores, then home-bucket indices. */
    LSH_hash_score(query->lsh, L, point, tmp, score);
    LSH_hash2_noperturb(query->lsh, tmp, tmp2, L);

    for (i = 0; i < L; i++)
    {
        memset(_topk[i], 0xff, sizeof (*_topk[i]) * K);
        TOPK_INIT(_topk[i], dist, K, HUGE_VAL);
        /* Ids already visited (e.g. via an earlier table) are skipped. */
        ARRAY_BEGIN_FOREACH(lsh->hash[i].bucket[tmp2[i]], uint32_t id) {
            if (!bitmap_contain(query->bitmap, id))
            {
                cass_vec_t *vec;
                bitmap_insert(query->bitmap, id);
                vec = DATASET_VEC(query->ds, id);
                entry.id = id;
                entry.dist = dist_L2_float(D, vec->u.float_data, point);
                C[i]++;
                query->CC++;
                TOPK_INSERT_MIN_UNIQ_DO(_topk[i], dist, id, K, entry, H[i]++);
            }
        }
        ARRAY_END_FOREACH;
        ptb_qsort(score[i], lsh->M * 2);
#ifdef QUERY_DIRECT
        {
            /* LSH_hash_score() fills only .set/.dir/.key, but
               LSH_query_probe() reads .max as the index into score[i] of the
               last element of a perturbation vector and uses it to index
               score[i]. The seed vector consists of score[i][0] alone, so
               its .max must be 0 rather than whatever the allocation left. */
            ptb_vec_t seed = score[i][0];
            seed.max = 0;
            ARRAY_TRUNC(query->heap[i]);
            HEAP_ENQUEUE(query->heap[i], seed, ptb_vec_ge);
        }
#else
        query->ptb_step[i] = 0;
        map_perturb_vector(query->ptb_set, query->ptb_vec[i], score[i], lsh->M, query->T);
#endif
    }

    if (lsh->est != NULL)
    {
        LSH_query_merge(query);
        LSH_query_local(query);
    }
}
/*
 * Probe one more bucket of hash table l and fold every not-yet-visited
 * dataset entry found there into a top-K list.
 *
 * l: hash table index (0 <= l < query->L).
 * g: 0 -> insert into the per-table list query->_topk[l];
 *    nonzero -> insert directly into the merged list query->topk.
 *
 * The bucket comes from the next perturbation vector of table l: with
 * QUERY_DIRECT, the head of query->heap[l] (its shift/expand successors are
 * pushed afterwards); otherwise the next precomputed entry
 * query->ptb_vec[l][query->ptb_step[l]]. In both modes the function returns
 * without probing once table l has no perturbation vectors left.
 */
static void LSH_query_probe (LSH_query_t *query, const float *point, int l, int g)
{
    cass_size_t D = query->lsh->D;
    cass_size_t K = query->K;
    LSH_t *lsh = query->lsh;
    uint32_t **tmp = query->tmp;
#ifdef QUERY_DIRECT
    cass_size_t M = query->lsh->M;
    ptb_vec_t *score = query->ptb[l];
#endif
    cass_list_entry_t *topk = g == 0 ? query->_topk[l] : query->topk;
    cass_list_entry_t entry;
    ptb_vec_t ptb;
    int *C = query->C;
    int *H = query->H;
    uint32_t h;
#ifdef QUERY_DIRECT
    typeof(query->heap) heap = &query->heap[l];
    if (HEAP_EMPTY(*heap)) return;
    ptb = HEAP_HEAD(*heap);
    HEAP_DEQUEUE(*heap, ptb_vec_ge);
#else
    /* ptb_vec[l] holds exactly T vectors (allocated L x T in
       LSH_query_init); stop instead of reading past the end of the row,
       mirroring the empty-heap return of the QUERY_DIRECT path. */
    if ((cass_size_t)query->ptb_step[l] >= query->T) return;
    ptb = query->ptb_vec[l][query->ptb_step[l]++];
#endif
    LSH_hash2_perturb(query->lsh, tmp, &h, &ptb, l);
    ARRAY_BEGIN_FOREACH(lsh->hash[l].bucket[h], uint32_t id) {
        if (!bitmap_contain(query->bitmap, id))
        {
            cass_vec_t *vec;
            bitmap_insert(query->bitmap, id);
            vec = DATASET_VEC(query->ds, id);
            C[l]++;
            query->CC++;
            entry.id = id;
            entry.dist = dist_L2_float(D, vec->u.float_data, point);
            TOPK_INSERT_MIN_UNIQ_DO(topk, dist, id, K, entry, H[l]++);
        }
    }
    ARRAY_END_FOREACH;
#ifdef QUERY_DIRECT
    {
        ptb_vec_t ptb2;
        int m;
        /* Successors of the vector just probed, whose last element is
           score[m]: "expand" (ptb) adds score[m+1]; "shift" (ptb2) replaces
           score[m] with score[m+1]. Only vectors whose new element does not
           touch an already-perturbed coordinate are enqueued. */
        m = ptb.max;
        ptb.max++;
        if (ptb.max >= M * 2) return;
        ptb2 = ptb;
        ptb2.set &= ~score[m].set;
        ptb2.dir &= ~score[m].dir;
        ptb2.key -= score[m].key;
        if ((ptb.set & score[ptb.max].set) == 0)
        {
            ptb.set |= score[ptb.max].set;
            ptb.dir |= score[ptb.max].dir;
            ptb.key += score[ptb.max].key;
            HEAP_ENQUEUE(*heap, ptb, ptb_vec_ge);
        }
        if ((ptb2.set & score[ptb2.max].set) == 0)
        {
            ptb2.set |= score[ptb2.max].set;
            ptb2.dir |= score[ptb2.max].dir;
            ptb2.key += score[ptb2.max].key;
            HEAP_ENQUEUE(*heap, ptb2, ptb_vec_ge);
        }
    }
#endif
}
/*
 * Rebuild the merged result list query->topk from scratch by inserting
 * every entry of the L per-table lists query->_topk[0..L-1] with
 * TOPK_INSERT_MIN_UNIQ.
 */
void LSH_query_merge (LSH_query_t *query)
{
    cass_size_t K = query->K;
    cass_size_t L = query->L;
    cass_list_entry_t *topk = query->topk;
    int table, rank;

    memset(topk, 0xff, sizeof (*topk) * K);
    TOPK_INIT(topk, dist, K, HUGE_VAL);

    for (table = 0; table < L; table++)
    {
        cass_list_entry_t *row = query->_topk[table];
        for (rank = 0; rank < K; rank++)
        {
            TOPK_INSERT_MIN_UNIQ(topk, dist, id, K, row[rank]);
        }
    }
}
/*
 * Standard multi-probe query.
 *
 * Bootstraps (home buckets of all tables), then probes each table up to T
 * times in table order, inserting into the per-table lists, and finally
 * merges them into query->topk. When an estimator is configured, the query
 * returns early as soon as topk[0].dist <= query->dist, the threshold that
 * LSH_query_bootstrap() has just computed for THIS point.
 *
 * NOTE(review): probes with g == 0 write into query->_topk[i], not
 * query->topk, so topk[0].dist only changes during bootstrap; the early
 * exit can therefore fire only before the first probe. Confirm whether a
 * periodic merge was intended.
 */
void LSH_query (LSH_query_t *query, const float *point)
{
    LSH_est_t *est = query->lsh->est;
    cass_list_entry_t *topk = query->topk;
    cass_size_t T = query->T;
    cass_size_t L = query->L;
    float dist;
    int i, j;

    query->CC = 0;
    LSH_query_bootstrap(query, point);
    /* Read the threshold only after bootstrap: LSH_query_local() (called
       from bootstrap when est != NULL) recomputes query->dist for this
       point. Reading it earlier used the previous query's estimate. */
    dist = query->dist;

    for (i = 0; i < L; i++)
    {
        for (j = 0; j < T; j++)
        {
            if (est != NULL && topk[0].dist <= dist) return;
            LSH_query_probe(query, point, i, 0);
        }
    }
    LSH_query_merge(query);
}
/*
 * Recall-targeted query.
 *
 * After bootstrapping and merging, runs up to T rounds; each round probes
 * every table once, inserting directly into query->topk (g = 1). Before each
 * round, the mean of LSH_recall(recall, topk[e].dist, round) over the K
 * current results is compared with R, and the query stops once the mean
 * reaches R.
 */
void LSH_query_recall (LSH_query_t *query, const float *point, float R)
{
    LSH_recall_t *recall = &query->lsh->recall;
    cass_list_entry_t *topk = query->topk;
    cass_size_t T = query->T;
    cass_size_t L = query->L;
    int K = query->K;
    int round;

    query->CC = 0;
    LSH_query_bootstrap(query, point);
    LSH_query_merge(query);

    for (round = 0; round < T; round++)
    {
        /* recall is the address of a struct member and so never NULL;
           the mean recall is always evaluated. */
        float mean = 0;
        int e, table;
        for (e = 0; e < K; e++)
        {
            mean += LSH_recall(recall, topk[e].dist, round);
        }
        mean /= K;
        if (mean >= R) return;

        for (table = 0; table < L; table++)
        {
            LSH_query_probe(query, point, table, 1);
        }
    }
}
/*
 * Choose the hash table to probe next.
 *
 * Tables are ordered by their score query->S[] using __cass_list_entry_qsort
 * (presumably ascending by dist — confirm) and walked in that order. Each
 * table is picked when rand() % 3 < 2 (roughly probability 2/3); if none is
 * picked, the last table in the order is returned.
 * The l argument (the previous choice) is currently unused.
 */
int LSH_query_select (LSH_query_t *query, int l)
{
    cass_list_entry_t order[query->L];
    int t;

    (void)l;

    for (t = 0; t < query->L; t++)
    {
        order[t].id = t;
        order[t].dist = query->S[t];
    }
    __cass_list_entry_qsort(order, query->L);

    t = 0;
    while (t < query->L - 1 && rand() % 3 >= 2)
    {
        t++;
    }
    return order[t].id;
}
/*
 * Adaptive query.
 *
 * After bootstrapping, spends T single probes, each on the table chosen by
 * LSH_query_select(). After probing table l its score is updated:
 *     S[l] += (1 - H[l]/C[l]) * _topk[l][0].dist   when C[l] > 0,
 *     S[l]  = 0                                     otherwise.
 * Finally the per-table lists are merged into query->topk.
 */
void LSH_query_boost (LSH_query_t *query, const float *point)
{
    cass_size_t T = query->T;
    int round;
    int l = 0;

    query->CC = 0;
    LSH_query_bootstrap(query, point);

    for (round = 0; round < T; round++)
    {
        int probed, hits;

        l = LSH_query_select(query, l);
        LSH_query_probe(query, point, l, 0);

        probed = query->C[l];
        hits = query->H[l];
        if (probed > 0)
        {
            query->S[l] += (1.0 - (double)hits / (double)probed) * query->_topk[l][0].dist;
        }
        else
        {
            query->S[l] = 0;
        }
    }
    LSH_query_merge(query);
}