Questa è una difficoltà concettuale comune quando si impara ad usare efficacemente NumPy . Normalmente, l'elaborazione dei dati in Python è espressa al meglio in termini di iteratori , per mantenere basso l'utilizzo della memoria, per massimizzare le opportunità di parallelismo con il sistema I / O e per il riutilizzo e la combinazione di parti di algoritmi .
Ma NumPy capovolge tutto questo: l'approccio migliore è esprimere l'algoritmo come una sequenza di operazioni dell'intero array , per ridurre al minimo la quantità di tempo trascorso nell'interprete Python lento e massimizzare il tempo quantità di tempo trascorso in routine NumPy veloci compilate.
Ecco l'approccio generale che prendo:
-
Conserva la versione originale della funzione (che sei sicuro sia corretta) in modo che tu possa testarla con le tue versioni migliorate sia per correttezza che per velocità.
-
Lavora dall'interno verso l'esterno: cioè, inizia con il ciclo più interno e vedi se può essere vettorializzato; poi, quando hai finito, sposta un livello e continua.
-
Trascorri molto tempo a leggere la documentazione di NumPy . Ci sono un sacco di funzioni e operazioni in là e non sono sempre brillantemente nominati, quindi vale la pena conoscerli. In particolare, se ti trovi a pensare, "se solo ci fosse una funzione che ha fatto questo e così", allora vale la pena spendere dieci minuti per cercarlo. Di solito è lì da qualche parte.
Non c'è alcun sostituto per la pratica, quindi ho intenzione di darti alcuni problemi di esempio. L'obiettivo di ogni problema è di riscrivere la funzione in modo che sia completamente vettorializzato : cioè, in modo che esso sia costituito da una sequenza di operazioni NumPy su interi array, senza loop Python nativi (no for
o while
istruzioni, nessun iteratore o comprensione).
Problema 1
def sumproducts(x, y):
"""Return the sum of x[i] * y[j] for all pairs of indices i, j.
>>> sumproducts(np.arange(3000), np.arange(3000))
20236502250000
"""
result = 0
for i in range(len(x)):
for j in range(len(y)):
result += x[i] * y[j]
return result
Problema 2
def countlower(x, y):
"""Return the number of pairs i, j such that x[i] < y[j].
>>> countlower(np.arange(0, 200, 2), np.arange(40, 140))
4500
"""
result = 0
for i in range(len(x)):
for j in range(len(y)):
if x[i] < y[j]:
result += 1
return result
Problema 3
def cleanup(x, missing=-1, value=0):
"""Return an array that's the same as x, except that where x ==
missing, it has value instead.
>>> cleanup(np.arange(-3, 3), value=10)
... # doctest: +NORMALIZE_WHITESPACE
array([-3, -2, 10, 0, 1, 2])
"""
result = []
for i in range(len(x)):
if x[i] == missing:
result.append(value)
else:
result.append(x[i])
return np.array(result)
Spoiler qui sotto. Otterrai i migliori risultati se ti avventuri prima di guardare le mie soluzioni!
Risposta 1
np.sum(x) * np.sum(y)
Risposta 2
np.sum(np.searchsorted(np.sort(x), y))
Risposta 3
np.where(x == missing, value, x)