Ce este Google JAX? Tot ce trebuie să știți

Google JAX, cunoscut și ca Just After Execution, este un cadru de lucru dezvoltat de Google, menit să accelereze sarcinile din domeniul învățării automate.

Imaginează-ți-l ca pe o bibliotecă Python ce facilitează execuția rapidă a sarcinilor, calculele științifice, transformările de funcții, învățarea profundă, rețelele neuronale și multe altele.

Informații despre Google JAX

Pachetul fundamental pentru calcule în Python este NumPy, ce include funcții esențiale precum agregarea, operațiile vectoriale, algebra liniară, manipularea matricelor n-dimensionale și altele avansate.

Dar dacă am putea accelera și mai mult calculele realizate cu NumPy, în special pentru seturi de date voluminoase?

Nu ar fi ideal un instrument ce ar funcționa la fel de bine pe diferite tipuri de procesoare, precum GPU sau TPU, fără a fi necesare modificări în cod?

Și dacă acest sistem ar putea efectua automat transformări ale funcțiilor, într-un mod compozabil și mai eficient?

Google JAX este o bibliotecă (sau cadru, conform Wikipedia) care exact asta face și chiar mai mult. A fost concepută pentru a optimiza performanța și a executa eficient sarcinile de învățare automată (ML) și învățare profundă. JAX oferă următoarele funcții de transformare, care o diferențiază de alte biblioteci ML, facilitând calcule științifice avansate pentru învățarea profundă și rețele neuronale:

  • Diferențiere automată
  • Vectorizare automată
  • Paralelizare automată
  • Compilare just-in-time (JIT)

Caracteristicile unice ale Google JAX

Toate transformările utilizează XLA (Accelerated Linear Algebra) pentru a obține performanțe superioare și optimizarea memoriei. XLA este un motor de compilare de optimizare specific domeniului, care realizează algebră liniară și accelerează modelele TensorFlow. Utilizarea XLA alături de codul Python nu necesită modificări majore ale acestuia!

Să explorăm mai detaliat fiecare dintre aceste caracteristici.

Caracteristicile Google JAX

Google JAX oferă funcții de transformare compozabile, esențiale pentru îmbunătățirea performanței și realizarea mai eficientă a sarcinilor de învățare profundă. De exemplu, diferențierea automată permite obținerea gradientului unei funcții și calcularea derivatelor de orice ordin. În mod similar, paralelizarea automată și JIT facilitează efectuarea mai multor sarcini în paralel. Aceste transformări sunt cruciale pentru aplicații precum robotica, jocurile și cercetarea.

O funcție de transformare compozabilă este o funcție pură care modifică un set de date într-o altă formă. Sunt denumite compozabile deoarece sunt autonome (adică, nu depind de restul programului) și apatride (aceeași intrare va produce întotdeauna aceeași ieșire).

Y(x) = T: (f(x))

În ecuația de mai sus, f(x) este funcția inițială, asupra căreia se aplică o transformare. Y(x) este funcția rezultată după aplicarea transformării.

De exemplu, dacă ai o funcție numită „total_bill_amt” și vrei rezultatul ca o transformare a funcției, poți folosi simplu transformarea dorită, să zicem gradient (grad):

grad_total_bill = grad(total_bill_amt)

Prin transformarea funcțiilor numerice cu funcții precum grad(), putem obține cu ușurință derivatele lor de ordin superior, utile în algoritmii de optimizare a învățării profunde, cum ar fi coborârea gradientului, accelerând și eficientizând astfel algoritmii. Similar, cu jit(), putem compila programe Python just-in-time (lazy).

#1. Diferențierea automată

Python utilizează funcția autograd pentru a diferenția automat codul NumPy și codul nativ Python. JAX folosește o versiune modificată a autograd (grad) și o combină cu XLA (Algebra liniară accelerată) pentru a realiza diferențierea automată și a calcula derivate de orice ordin pentru GPU-uri (Unități de procesare grafică) și TPU-uri (Unități de procesare tensorială).

O scurtă notă despre TPU, GPU și CPU: CPU sau unitatea centrală de procesare gestionează toate operațiile computerului. GPU este un procesor adițional care sporește capacitatea de calcul și rulează operații complexe. TPU este o unitate puternică, dezvoltată special pentru sarcini dificile și grele, precum AI și algoritmii de învățare profundă.

Similar funcției autograd, care diferențiază prin bucle, recursiuni, ramuri etc., JAX utilizează funcția grad() pentru a calcula gradienții în mod invers (backpropagation). Putem de asemenea diferenția o funcție de orice ordin folosind grad:

grad(grad(grad(sin θ))) (1.0)

Diferențierea automată de ordin superior

După cum am menționat anterior, grad este utilă pentru a găsi derivatele parțiale ale unei funcții. Derivata parțială poate fi folosită pentru a calcula coborârea gradientului unei funcții de cost în raport cu parametrii rețelei neuronale în învățarea profundă, cu scopul de a minimiza pierderile.

Calculul derivatei parțiale

Să presupunem că o funcție are mai multe variabile, x, y și z. Calcularea derivatei unei variabile, menținând constante celelalte variabile, se numește derivată parțială. Presupunem că avem o funcție:

f(x,y,z) = x + 2y + z2

Exemplu pentru a ilustra derivata parțială

Derivata parțială a lui x va fi ∂f/∂x, indicând cum se modifică o funcție pentru o variabilă când celelalte sunt constante. Dacă am face acest lucru manual, ar trebui să scriem un program de diferențiere, să îl aplicăm fiecărei variabile și apoi să calculăm coborârea gradientului. Aceasta ar deveni o operațiune complexă și consumatoare de timp pentru mai multe variabile.

Diferențierea automată descompune funcția într-un set de operații elementare, precum +, -, *, / sau sin, cos, tan, exp etc., apoi aplică regula lanțului pentru a calcula derivata. Acest lucru poate fi realizat atât în modul direct, cât și în modul invers.

Și asta nu este tot! Toate aceste calcule se întâmplă foarte rapid (imaginează-ți un milion de calcule similare cu cele de mai sus și timpul necesar!). XLA se ocupă de viteză și performanță.

#2. Algebră liniară accelerată

Să luăm ecuația anterioară. Fără XLA, calculul ar dura trei (sau mai multe) nuclee, fiecare nucleu îndeplinind o sarcină mai mică. De exemplu:

Kernel k1 -> x * 2y (înmulțire)

k2 -> x * 2y + z (adunare)

k3 -> Reducere

Dacă aceeași sarcină este realizată de XLA, un singur nucleu gestionează toate operațiile intermediare prin fuzionarea lor. Rezultatele intermediare ale operațiilor elementare sunt transmise în flux în loc să fie stocate în memorie, economisind memorie și sporind viteza.

#3. Compilare just-in-time

JAX utilizează intern compilatorul XLA pentru a crește viteza de execuție. XLA poate accelera atât CPU, cât și GPU și TPU. Toate acestea sunt posibile datorită execuției codului JIT. Pentru a folosi această funcție, putem importa jit:

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

O altă metodă este decorarea jit deasupra definiției funcției:

@jit
def my_function(x):
	…………some lines of code

Acest cod este mai rapid, deoarece transformarea va returna versiunea compilată a codului apelantului, în loc să folosească interpretorul Python. Este util mai ales pentru intrări vectoriale, precum matrice și matrici.

Același lucru este valabil și pentru funcțiile Python existente, de exemplu cele din pachetul NumPy. În acest caz, ar trebui să importăm jax.numpy ca jnp, în loc de NumPy:

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

După import, obiectul matrice JAX de bază, denumit DeviceArray, înlocuiește matricea standard NumPy. DeviceArray este „leneș” – valorile sunt păstrate în accelerator până când sunt necesare. Aceasta înseamnă că programul JAX nu așteaptă ca rezultatele să revină la programul apelant (Python), urmând o expediere asincronă.

#4. Vectorizare automată (vmap)

În lumea învățării automate, seturile de date pot include milioane sau mai multe puncte de date. Cel mai probabil, vom efectua calcule sau manipulări pe fiecare sau pe majoritatea acestor puncte, o sarcină ce consumă mult timp și memorie! De exemplu, pentru a calcula pătratul fiecărui punct de date, primul lucru la care ne gândim este să creăm o buclă și să calculăm pătratul unul câte unul – complicat!

Dacă creăm aceste puncte ca vectori, am putea calcula toate pătratele dintr-o singură operație, folosind manipulări vectoriale sau matriceale cu NumPy. Dar dacă programul ar putea face asta automat? JAX exact asta face! Vectorizează automat toate punctele de date, permițând efectuarea ușoară a oricăror operații, făcând algoritmii mai rapizi și mai eficienți.

JAX folosește funcția vmap pentru auto-vectorizare. Să luăm în considerare următoarea matrice:

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

Făcând doar cele de mai sus, metoda pătratului se va executa pentru fiecare punct din matrice. Dar dacă facem următoarele:

vmap(jnp.square(x))

Metoda pătratului se va executa o singură dată, deoarece punctele de date sunt acum vectorizate automat cu metoda vmap înainte de executarea funcției, iar bucla este mutată la nivelul elementar de operare, rezultând o multiplicare matriceală în locul unei multiplicări scalare, ceea ce îmbunătățește performanța.

#5. Programare SPMD (pmap)

SPMD (Single Program Multiple Data) este esențială în contextul învățării profunde – adesea vom aplica aceleași funcții pe diferite seturi de date aflate pe mai multe GPU-uri sau TPU-uri. JAX are o funcție numită pmap, ce permite programarea în paralel pe mai multe GPU-uri sau orice accelerator. Similar lui JIT, programele care folosesc pmap vor fi compilate de XLA și executate simultan în toate sistemele. Această paralelizare automată funcționează atât pentru calculele directe, cât și pentru cele inverse.

Putem aplica, de asemenea, transformări multiple dintr-o singură mișcare, în orice ordine, pe orice funcție, ca:

pmap(vmap(jit(grad (f(x)))))

Transformări compozabile multiple

Limitările Google JAX

Dezvoltatorii Google JAX s-au concentrat pe accelerarea algoritmilor de învățare profundă, introducând aceste transformări extraordinare. Funcțiile și pachetele de calcul științific sunt similare NumPy, astfel încât curba de învățare este ușoară. Cu toate acestea, JAX are câteva limitări:

  • Google JAX este încă în dezvoltare timpurie și, deși scopul său principal este optimizarea performanței, nu aduce prea multe beneficii pentru calculul CPU. NumPy pare să funcționeze mai bine, iar utilizarea JAX poate crește suprasarcina.
  • JAX este încă în stadiu de cercetare sau incipient și necesită mai multe ajustări pentru a atinge standardele de infrastructură ale cadrelor precum TensorFlow, care sunt mai stabile și au mai multe modele predefinite, proiecte open-source și materiale de învățare.
  • În prezent, JAX nu este compatibil cu sistemul de operare Windows – este necesară o mașină virtuală pentru a-l utiliza.
  • JAX funcționează doar pe funcții pure, cele care nu au efecte secundare. Pentru funcții cu efecte secundare, JAX s-ar putea să nu fie o opțiune bună.

Cum să instalezi JAX în mediul tău Python

Dacă ai o configurare Python pe sistem și vrei să rulezi JAX pe mașina locală (CPU), folosește următoarele comenzi:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Dacă dorești să rulezi Google JAX pe GPU sau TPU, urmează instrucțiunile de pe pagina GitHub JAX. Pentru a configura Python, vizitează pagina oficială de descărcări Python.

Concluzie

Google JAX este excelent pentru crearea algoritmilor eficienți de învățare profundă, robotică și cercetare. În ciuda limitărilor, este folosit pe scară largă cu alte cadre precum Haiku, Flax și altele. Vei putea aprecia capacitățile JAX când vei rula programe și vei observa diferențele de timp în executarea codului cu și fără JAX. Poți începe prin a citi documentația oficială Google JAX, care este destul de cuprinzătoare.