Il modo più veloce per trovare N vettori più vicini per il vettore X nella lista?

4

Ho un mazzo (~ 20.000) di grandi (~ 200 dimensioni) vettori in una lista non ordinata. Posso creare un nuovo vettore della stessa dimensione, e mi piacerebbe trovare i migliori vettori esistenti N (di solito 10 o più) più vicini (definiti dalla similarità del coseno) dalla mia lista. In questo momento il mio approccio è quello di generare una lista di differenze vettoriali e poi ordinarle e prendere la top 10, ma ci vuole un po 'di tempo (2 o 3 secondi) per confronto, che va bene una alla volta ma si accumula davvero quando ho bisogno fare un gruppo in una volta.

Sono aperto anche alla pre-elaborazione della mia lista di vettori (dico solo che è non ordinata poiché non so se sia possibile ordinare una lista di vettori n-dimensionali) se ciò fosse di aiuto. Saranno sempre della stessa dimensione del vettore. Per il contesto, i vettori sono i risultati di word2vec.

    
posta IronWaffleMan 26.03.2018 - 02:33
fonte

1 risposta

1

Ho avuto lo stesso problema: trovare il n più piccolo di N interi. Si noti che per trovare quello più piccolo è necessario il confronto N, quindi la mia soluzione utilizza i confronti n * N per trovare il n più piccolo, piuttosto che N ^ 2 per un semplice ordinamento a bolle. Ecco il codice C, insieme a un semplice driver di test che ho scritto per questo post ...

/* --- standard headers --- */
#include <stdio.h>
#include <stdlib.h>

/*===========================================================================
 * Function: nsmallest ( n, nx, x )
 * Purpose:  finds the n smallest values in x[nx], returning their indexes
 * --------------------------------------------------------------------------
 * Arguments: n (I) int containing number of smallest x[nx]'s
 *                  whose indexes are to be returned
 *           nx (I) int containing number of values in x[nx]
 *           x (I)  int* containing nx values, the indexes
 *                  of whose smallest n values are to be returned
 * Returns: (int *) list of indexes containing the smallest
 *                  n values in x[nx].
 * --------------------------------------------------------------------------
 * Notes:     o
 *=========================================================================*/
int *nsmallest ( int n, int nx, int *x ) {
  static int indexes[999];  /* returned indexes of n smallest x[nx]'s */
  int   ix = 0,         /* x[] index */
    index=0, jndex=0,   /* indexes[] indexes */
    nindexes = 1;       /* number of smallest x[]'s found so far */
  indexes[0] = 0;       /* init indexes[] list with first x[] */
  for ( ix=1; ix<nx; ix++ ) {   /* search for n smallest x[nx]'s */
    for ( index=0; index<nindexes; index++ ) { /* compare x[ix] to indexes[] */
      if ( x[ix] < x[indexes[index]] ) { /* put ix before indexes[index] */
        for ( jndex=nindexes-1; jndex>=index; jndex-- ) /* work backwards */
          indexes[jndex+1] = indexes[jndex]; /* move each indexes[] "down" */
        indexes[index] = ix;    /* put current ix in now-vacant slot */
        break;          /* no need for further comparisons */
        } /* --- end-of-if(x[ix]<x[indexes[index]]) --- */
      } /* --- end-of-for(index) --- */
    if ( nindexes < n ) {   /* still need more smallest x[nx]'s */
      if ( index >= nindexes ) indexes[nindexes] = ix; /* ix in last slot */
      nindexes++; }     /* count another smallest x[nx] */
    } /* --- end-of-for(ix) --- */
  return ( indexes );       /* indexes of n smallest x[nx]'s to caller */
  } /* --- end-of-function nsmallest() --- */

#ifdef TESTDRIVE
int main ( int argc, char *argv[] ) {
  int   n     = ( argc>1? atoi(argv[1]) : 10 ),
    nx    = ( argc>2? atoi(argv[2]) : 9999 ),
    seed  = ( argc>3? atoi(argv[3]) : 987654321 );
  double xmax = ( argc>4? (double)atoi(argv[4]) : 999999.0 );
  int   x[99999], ix=0, *indexes=NULL;
  srand(seed);
  for ( ix=0; ix<nx; ix++ )
    x[ix] = (int)( xmax*((double)rand())/((double)RAND_MAX) );
  indexes = nsmallest(n,nx,x);
  printf("%d smallest x[%d]'s...\n",n,nx);
  for ( ix=0; ix<n; ix++ )
    printf("  %d) x[%d] = %d\n", ix+1,indexes[ix],x[indexes[ix]]);
  exit ( 0 );
  } /* --- end-of-main() nsmallest test driver --- */
#endif
/* ------------------------ end-of-file nsmallest.c ---------------------- */

Modifica Ho scritto rapidamente quanto sopra molto tempo fa per i miei scopi, che non erano particolarmente critici dal punto di vista del tempo, quindi il codice postato andava bene per me. Ma dopo averlo postato e guardato di nuovo, ho notato che il ciclo "index" si sposta dal più piccolo al più grande, e significa che deve passare attraverso l'intero ciclo per ogni numero candidato prima di poter scartare quel numero.

Quindi, soprattutto per i calci, l'ho riscritto passando dal più grande al più piccolo. Quindi un candidato può essere scartato immediatamente se è già più grande del numero più piccolo nella lista. E anche io (anche se questo era solo un leggero miglioramento) ho sostituito il ciclo "jndex", che "fa spazio" per un piccolo numero appena trovato, da un singolo memmove ().

E ora, testando i primi 150 numeri su 999000 (ho aumentato la dimensione dell'array x [] per questo test), il tempo è passato da 0,337 secondi a 0,012 secondi. Fondamentalmente, tutti i numeri vengono scartati immediatamente, poiché è raro imbattersi in un numero candidato più piccolo di quelli già più piccoli. Quindi stai solo facendo un po 'più di N confronti, molto meno del precedente n * N, per trovare i n numeri più piccoli.

    
risposta data 26.03.2018 - 22:32
fonte

Leggi altre domande sui tag