嘿,Python愛好者!您是否希望以超音速速度運行Numpy代碼?認識JAX!。您在機器學習,深度學習和數字計算過程中的新最好的朋友。將其視為具有超能力的Numpy。它可以自動處理梯度,編譯代碼以使用JIT快速運行,甚至可以在GPU和TPU上運行而不會破壞汗水。無論您是構建神經網絡,處理科學數據,調整變壓器模型,還是只是試圖加快計算速度,JAX都會支持您。讓我們深入研究,看看是什麼使Jax如此特別。
本指南詳細介紹了JAX及其生態系統。
本文作為數據科學博客馬拉鬆的一部分發表。
根據官方文檔,JAX是用於加速陣列計算和程序轉換的Python庫,專為高性能數值計算和大規模機器學習而設計。因此,JAX本質上是類固醇上的Numpy,它將熟悉的Numpy風格操作與自動分化和硬件加速相結合。可以將其視為獲得三個世界中最好的。
設定JAX的是其轉變。這些是可以修改您的Python代碼的強大功能:
這是一個快速的外觀:
導入jax.numpy作為jnp 來自Jax Import Grad,Jit #定義一個簡單的功能 @Jit#用編譯加快速度 def square_sum(x): 返回JNP.Sum(JNP.Square(x)) #自動獲取其梯度功能 gradient_fn = grad(square_sum) #嘗試一下 x = jnp.Array([1.0,2.0,3.0]) 打印(f“漸變:{gradient_fn(x)}”)
輸出:
漸變:[2。 4。6。]]
在下面,我們將遵循一些步驟以開始使用JAX。
設置jax非常適合僅使用CPU。您可以使用JAX文檔以獲取更多信息。
為您的項目創建CONDA環境
#為JAX創建Conda Env $ conda create -name jaxdev python = 3.11 #激活Env $ conda激活jaxdev #創建一個項目dir name jax101 $ MKDIR JAX101 #進入DIR $ CD JAX101
在新創建的環境中安裝JAX
#僅適用於CPU PIP安裝 - 升級PIP PIP安裝 - 升級“ JAX” #對於GPU PIP安裝 - 升級PIP PIP安裝 - 升級“ JAX [CUDA12]”
現在,您準備深入研究真實的事物。在實用編碼上弄髒您的手之前,讓我們學習一些新概念。我將首先解釋這些概念,然後我們將共同編碼以了解實際的觀點。
首先,順其自然,為什麼我們再次學習新圖書館?我將在本指南中以盡可能簡單的方式回答這個問題。
將JAX視為電動工具。儘管Numpy就像是可靠的手鋸,但Jax就像現代的電鋸。它需要更多的步驟和知識,但是對於密集的計算任務而言,性能好處是值得的。
在下一節中,我們將深入研究Jax的轉換,從JIT彙編開始。這些轉變是賦予其超級大國的Jax的原因,而理解它們是有效利用JAX的關鍵。
JAX的轉換是真正將其與數值計算庫(例如Numpy或Scipy)區分開來的。讓我們探索每個人,看看它們如何增強您的代碼。
Just-Amper Ampilation通過在運行時(而不是提前編制程序)來優化代碼執行。
在JAX中,JAX.JIT將Python函數轉換為JIT編譯版本。用 @jax.jit裝飾功能可捕獲其執行圖,優化它並使用XLA對其進行編譯。然後,編譯的版本執行,提供了重大的加速,尤其是對於重複的功能調用。
這是您可以嘗試的方法。
導入jax.numpy作為jnp 來自JAX Import Jit 進口時間 #計算密集型功能 def slow_function(x): 對於_範圍(1000): x = jnp.sin(x)jnp.cos(x) 返回x #與JIT相同的功能 @Jit def fast_function(x): 對於_範圍(1000): x = jnp.sin(x)jnp.cos(x) 返回x
這是相同的功能,一個只是一個普通的python彙編過程,另一個函數用作JAX的JIT彙編過程。它將計算正弦和余弦函數的1000個數據點總和。我們將使用時間比較性能。
#比較性能 X = JNP.Arange(1000) #熱身吉特 fast_function(x)#第一個調用編譯功能 #時間比較 start = time.time() slow_result = slow_function(x) 打印(f“沒有jit:{time.time() - 開始:.4f}秒”) start = time.time() fast_result = fast_function(x) 打印(f with jit:{time.time() - 開始:.4f}秒”)
結果將使您驚訝。 JIT彙編比正常彙編快333倍。這就像將自行車與Buggati Chiron進行比較。
輸出:
沒有JIT:0.0330秒 與JIT:0.0010秒
JIT可以為您提供超快速的執行力,但您必須正確使用它,否則就像在沒有提供超級跑車設施的泥濘鄉村道路上駕駛布加迪一樣。
JIT在靜態形狀和類型中最有效。避免使用取決於數組值的python循環和條件。 JIT不適用於動態陣列。
#不好 - 使用Python控制流 @Jit def bad_function(x): 如果x [0]> 0:#這與JIT無法正常工作 返回x 返回-x #print(bad_function(jnp.array([1,2,3]))) #好 - 使用jax控制流 @Jit def good_function(x): 返回jnp.Where(x [0]> 0,x,-x)#jax -native條件 打印(good_function(JNP.Array([1,2,3]))))))
輸出:
這意味著bad_function是不好的,因為JIT在計算過程中不在X的值中。
輸出:
[1 2 3]
自動分化或Autodiff是一種計算技術,用於準確有效地計算功能的導數。它在優化機器學習模型中起著至關重要的作用,尤其是在訓練神經網絡中,該網絡用於更新模型參數。
Autodiff通過將微積分的鏈規則應用於更簡單的功能,計算這些子功能的派生函數,然後結合結果。它在函數執行過程中記錄每個操作以構建計算圖,然後將其用於自動計算衍生物。
自動陷阱有兩種主要模式:
導入jax.numpy作為jnp 從jax進口畢業,value_and_grad #定義一個簡單的神經網絡層 def層(params,x): 重量,偏見=參數 返回jnp.dot(x,重量)偏差 #定義標量值損耗函數 def loss_fn(params,x): 輸出=圖層(參數,x) 返回JNP.SUM(輸出)#還原為標量 #獲得輸出和梯度 layer_grad = grad(loss_fn,argnums = 0)#相對於參數的漸變 layer_value_and_grad = value_and_grad(loss_fn,argnums = 0)#值和漸變 #示例用法 key = jax.random.prngkey(0) x = jax.random.normal(key,(3,4)) 重量= jax.random.normal(key,(4,2)) bias = jax.random.normal(key,(2,)) #計算梯度 grads = layer_grad((重量,偏見),x) 輸出,grads = layer_value_and_grad(((重量,偏見),x) #多個導數很容易 twice_grad = grad(grad(jnp.sin)) X = JNP.Array(2.0) print(f“ sin的第二個衍生物在x = 2:{twice_grad(x)}”)
輸出:
sin的第二個衍生物x = 2:-0.9092974066734314
在JAX中,“ VMAP”是一個強大的函數,可以自動矢量化計算,從而可以在無需手動編寫循環的情況下將功能應用於批次的數據。它可以在陣列軸(或多個軸)上繪製函數,並並行評估它,從而可以顯著改善性能。
VMAP函數可自動化沿輸入陣列的指定軸將函數應用於每個元素的過程,同時保留計算的效率。它轉換給定功能以接受批處理輸入並以矢量化的方式執行計算。
VMAP不是使用顯式循環,而是通過在輸入軸上進行矢量進行並行執行操作。這利用了硬件執行SIMD(單個指令,多個數據)操作的功能,這可能會導致大幅加速。
導入jax.numpy作為jnp 來自JAX導入VMAP #在單個輸入中起作用的功能 def single_input_fn(x): 返回jnp.sin(x)jnp.cos(x) #將其矢量化以在批處理 batch_fn = vmap(single_input_fn) #比較性能 X = JNP.Arange(1000) #沒有VMAP(使用列表理解) result1 = jnp.Array(x In xi in xi]) #與vmap 結果2 = batch_fn(x)#快得多! #矢量化多個參數 def兩_input_fn(x,y): 返回x * jnp.sin(y) #在兩個輸入上進行矢量化 vectorized_fn = vmap(tw_input_fn,in_axes =(0,0)) #或僅通過第一個輸入進行矢量化 partaly_vectorized_fn = vmap(tw_input_fn,in_axes =(0,none)) # 列印 打印(結果1.形) 打印(結果2.形狀) 打印(partaly_vectorized_fn(x,y).shape)
輸出:
(1000,) (1000,) (1000,3)
JAX為矩陣操作和線性代數提供了全面的支持,使其適合科學計算,機器學習和數值優化任務。 JAX的線性代數功能與諸如Numpy之類的庫中的功能相似,但具有其他功能,例如自動差異化和即時彙編,以進行優化的性能。
這些操作是相同形狀的元素矩陣進行的。
#1矩陣加法和減法: 導入jax.numpy作為jnp a = jnp.array([[[1,2],[3,4]]) b = jnp.Array([[[5,6],[7,8]]) #矩陣加法 C = AB #矩陣減法 d = a -b 打印(f“矩陣A:\ n {a}”) 打印(“ ========================== 打印(f“矩陣B:\ n {b}”) 打印(“ ========================== print(f“ ab:\ n {c}”的矩陣adtion”) 打印(“ ========================== 打印(f“ ab:\ n {d}的矩陣縮寫”)
輸出:
JAX支持元素乘法和基於DOR產品的矩陣乘法。
#元素乘法 e = a * b #矩陣乘法(點產品) f = jnp.dot(a,b) 打印(f“矩陣A:\ n {a}”) 打印(“ ========================== 打印(f“矩陣B:\ n {b}”) 打印(“ ========================== print(f“*b:\ n {e}的元素乘法”) 打印(“ ========================== print(f“ a*b:\ n {f}的矩陣乘法”)
輸出:
可以使用`
#矩陣 g = jnp.transpose(a) 打印(f“矩陣A:\ n {a}”) 打印(“ ========================== print(f“ a:\ n {g}的矩陣轉置”)
輸出:
JAX使用jnp.linalg.inv()`提供矩陣反轉的功能
#矩陣倒置 h = jnp.linalg.inv(a) 打印(f“矩陣A:\ n {a}”) 打印(“ ========================== print(f“ a:\ n {h}的矩陣反轉”)
輸出:
可以使用`jnp.linalg.det()``。
#矩陣決定因素 det_a = jnp.linalg.det(a) 打印(f“矩陣A:\ n {a}”) 打印(“ ========================== print(f“ a:\ n {det_a}”的矩陣決定因素”)
輸出:
您可以使用`jnp.linalg.eigh()計算矩陣的特徵值和特徵向量
#特徵值和特徵向量 導入jax.numpy作為jnp a = jnp.array([[[1,2],[3,4]]) 特徵值,特徵向量= jnp.linalg.eigh(a) 打印(f“矩陣A:\ n {a}”) 打印(“ ========================== print(a:\ n {eigenvalues}的f“ egenvalues”) 打印(“ ========================== print(a:\ n {eigenVectors}的f“ eigenVectors}”)
輸出:
通過`jnp.linalg.svd`支持SVD,可用於降低維度和矩陣分解。
#單數值分解(SVD) 導入jax.numpy作為jnp a = jnp.array([[[1,2],[3,4]]) u,s,v = jnp.linalg.svd(a) 打印(f“矩陣A:\ n {a}”) 打印(“ ========================== print(f“ matrix u:\ n {u}”) 打印(“ ========================== 打印(f“矩陣S:\ n {s}”) 打印(“ ========================== 打印(f“矩陣V:\ n {v}”)
輸出:
為了求解線性方程式AX = B的系統,我們使用`jnp.linalg.solve()`,其中A是平方矩陣,B是相同數量的行的向量或矩陣。
#線性方程的求解系統 導入jax.numpy作為jnp a = jnp.array([[[2.0,1.0],[1.0,3.0]]) B = JNP.Array([[5.0,6.0]) x = jnp.linalg.solve(a,b) 打印(f“ x:{x}的值”)
輸出:
x的值:[1.8 1.4]
使用JAX的自動分化,您可以計算標量功能相對於矩陣的梯度。
我們將計算以下功能的梯度和x的值
功能
#計算矩陣函數的梯度 導入JAX 導入jax.numpy作為jnp def matrix_function(x): 返回JNP.SUM(JNP.SIN(X)X ** 2) #計算功能的畢業 grad_f = jax.grad(matrix_function) x = jnp.Array([[[1.0,2.0],[3.0,4.0]])) 漸變= grad_f(x) 打印(f“矩陣x:\ n {x}”) 打印(“ ========================== 打印(f“ matrix_function的梯度:\ n {漸變}”)
輸出:
這些在數值計算,機器學習和物理計算中使用的JAX的最有用的功能。還有更多供您探索。
JAX具有科學計算的強大庫,JAX最適合科學計算,用於其提前特徵,例如JIT彙編,自動分化,矢量化,並行化和GPU-TPU加速度。 JAX支持高性能計算的能力使其適用於廣泛的科學應用,包括物理模擬,機器學習,優化和數值分析。
我們將在本節中探討一個優化問題。
讓我們瀏覽以下優化問題:
#定義一個函數以最小化(例如,Rosenbrock函數) @Jit Def Rosenbrock(X): 返回sum(100.0 *(x [1:] - x [: - 1] ** 2.0)** 2.0(1 -x [: - 1])** 2.0)
在這裡,定義了Rosenbrock函數,這是優化中常見的測試問題。該函數將數組x作為輸入,併計算一個代表x距函數全局最小值的valie。 @JIT裝飾器用於啟用JUT-IN-IN時間彙編,該彙編通過編譯功能在CPU和GPU上有效運行來加快計算的速度。
#梯度下降優化 @Jit def gradient_descent_step(x,Learning_rate): 返回X -Learning_rate * grad(Rosenbrock)(x)
此功能執行梯度下降優化的單一步驟。使用Grad(Rosenbrock)(X)計算Rosenbrock函數的梯度,該級提供了相對於X的導數。 X的新值通過減法更新,通過Learning_rate縮放梯度。@Jit的做法與以前相同。
# 最佳化 x = jnp.array([0.0,0.0])#起點 Learning_rate = 0.001 對於範圍的我(2000年): x = gradient_descent_step(x,Learning_rate) 如果我%100 == 0: print(f“步驟{i},值:{Rosenbrock(x):。4f}”)
優化循環初始化了起點X,並執行梯度下降的1000次迭代。在每次迭代中,gradient_descent_step函數基於當前梯度更新。每100個步驟,當前的步驟編號和X處的Rosenbrock函數的值,提供優化的進度。
輸出:
我們將模擬一個物理系統的運動系統的運動,該運動的運動震盪振盪器的運動模型,該系統像帶有摩擦的質量彈簧系統,車輛中的減震器或電路中的振盪一樣建模。不是很好嗎?我們開始做吧。
導入JAX 導入jax.numpy作為jnp #定義參數 質量= 1.0#對象的質量(kg) 阻尼= 0.1#阻尼係數(kg/s) spring_constant = 1.0#彈簧常數(n/m) #定義時間步驟和總時間 DT = 0.01#時間步長(S) num_steps = 3000#步驟數
定義了質量,阻尼係數和彈簧常數。這些決定了阻尼的諧波振盪器的物理特性。
#定義ODES系統 DEF DAMPED_HARMONIC_COSCILLATOR(狀態,T): “”“計算阻尼諧波振盪器的衍生物。 狀態:包含位置和速度的數組[X,V] T:時間(在此自治系統中不使用) ”“” x,v =狀態 dxdt = v dvdt = -Damping / Mass * V -Spring_constant / Mass * x 返回JNP.Array([DXDT,DVDT])
阻尼的諧波振盪器函數定義了振盪器的位置和速度的衍生物,代表了動力學系統。
#使用Euler的方法解決ODE def euler_step(狀態,t,dt): “”“執行Euler方法的一步。”“” 衍生物= damped_harmonic_coscillator(狀態,t) 返回狀態衍生工具 * DT
一種簡單的數值方法用於求解ode。它在下一個時間步驟近似於當前狀態和導數。
#初始狀態:[位置,速度] oniration_state = jnp.Array([1.0,0.0])#從質量開始,x = 1,v = 0 #時間演變 狀態= [initial_state] 時間= 0.0 對於範圍(num_steps)的步驟: next_state = euler_step(狀態[-1],時間,dt) states.append(next_state) 時間= DT #將狀態列表轉換為JAX數組進行分析 狀態= jnp.stack(狀態)
循環通過指定的時間步驟迭代,使用Euler的方法在每個步驟更新狀態。
輸出:
最後,我們可以繪製結果以可視化阻尼的諧波振盪器的行為。
#繪製結果 導入matplotlib.pyplot作為PLT plt.Style.use(“ GGPLOT”) 位置=狀態[:,0] 速度=狀態[:,1] time_points = jnp.arange(0,(num_steps 1) * dt,dt) plt.figure(無花果=(12,6)) plt.subplot(2,1,1) plt.plot(time_points,位置,label =“位置”) plt.xlabel(“時間”) plt.ylabel(“位置(M)”) plt.legend() plt.subplot(2,1,2) plt.plot(time_points,速度,label =“速度”,color =“橙色”) plt.xlabel(“時間”) plt.ylabel(“速度(m/s)”) plt.legend() plt.tight_layout() plt.show()
輸出:
我知道您渴望看到如何使用JAX構建神經網絡。因此,讓我們深入研究它。
在這裡,您可以看到這些值逐漸最小化。
JAX是一個功能強大的庫,將高性能數值計算與使用Numpy樣語法的易用性結合在一起。本節將指導您使用JAX構建神經網絡的過程,並利用其高級功能進行自動差異化和即時彙編以優化性能。
在我們深入建立神經網絡之前,我們需要進口必要的庫。 JAX提供了一套用於創建有效數值計算的工具,而其他庫將有助於優化和可視化我們的結果。
導入JAX 導入jax.numpy作為jnp 來自Jax Import Grad,Jit 來自jax.random導入prngkey,正常 導入Optax#JAX的優化庫 導入matplotlib.pyplot作為PLT
創建有效的模型層對於定義神經網絡的體系結構至關重要。在此步驟中,我們將初始化密集層的參數,以確保我們的模型從定義明確的權重和偏見開始,以進行有效學習。
def init_layer_params(key,n_in,n_out): “”“單個密集層的初始化參數”“” key_w,key_b = jax.random.split(key) #初始化 w = normal(key_w,(n_in,n_out)) * jnp.sqrt(2.0 / n_in) b = normal(key_b,(n_out,)) * 0.1 返回(w,b) def relu(x): “”“ relu激活函數”“” 返回jnp.maximum(0,x)
正向通行證是神經網絡的基石,因為它決定了輸入數據如何流過網絡以產生輸出。在這裡,我們將通過通過初始化層將轉換應用於輸入數據來定義一種計算模型輸出的方法。
def向前(參數,x): “”“前向兩個層神經網絡”“”“” (W1,B1),(W2,B2)=參數 #第一層 h1 = relu(jnp.dot(x,w1)b1) #輸出層 logits = jnp.dot(h1,w2)b2 返回logits
定義明確的損失功能對於指導我們模型的培訓至關重要。在此步驟中,我們將實施平均誤差(MSE)損耗函數,該函數衡量了預測輸出符合目標值的程度,從而使模型能夠有效學習。
def loss_fn(params,x,y): “”“平均平方錯誤損失”“” pred =向前(params,x) 返回jnp.mean(((pred -y)** 2)
通過定義了模型體系結構和損失函數,我們現在轉向模型初始化。此步驟涉及設置我們的神經網絡的參數,以確保每一層都準備以隨機但適當縮放的權重和偏見開始訓練過程。
def init_model(rng_key,input_dim,hidden_dim,output_dim): key1,key2 = jax.random.split(rng_key) params = [ init_layer_params(key1,input_dim,hidden_dim), init_layer_params(key2,hidden_dim,output_dim), 這是給出的 返回參數
訓練神經網絡涉及基於損耗函數的計算梯度對其參數的迭代更新。在此步驟中,我們將實施一個有效地應用這些更新的培訓功能,從而使我們的模型可以通過多個時期的數據學習。
@Jit def train_step(params,opt_state,x_batch,y_batch): 損失,grads = jax.value_and_grad(loss_fn)(params,x_batch,y_batch) 更新,opt_state =優化器。 params = optax.apply_updates(參數,更新) 返回參數,opt_state,損失
為了有效地培訓我們的模型,我們需要生成合適的數據並實施培訓循環。本節將介紹如何為我們的示例創建合成數據,以及如何跨多個批次和時代管理培訓過程。
#生成一些示例數據 key = prngkey(0) x_data = normal(鍵,(1000,10))#1000樣本,10個功能 y_data = jnp.sum(x_data ** 2,axis = 1,keepdims = true)#簡單的非線性函數 #初始化模型和優化器 params = init_model(key,input_dim = 10,hidden_dim = 32,output_dim = 1) 優化器= optax.adam(Learning_rate = 0.001) opt_state =優化器(params) #訓練循環 batch_size = 32 num_epochs = 100 num_batches = x_data.shape [0] // batch_size #存儲時期和損失值的數組 epoch_array = [] loss_array = [] 對於範圍(num_epochs)的時代: epoch_loss = 0.0 對於範圍(num_batches)的批次: idx = jax.random.permunt(鍵,batch_size) x_batch = x_data [idx] y_batch = y_data [idx] params,opt_state,loss = train_step(params,opt_state,x_batch,y_batch) epoch_loss =損失 #存儲時代的平均損失 avg_loss = epoch_loss / num_batches epoch_array.append(epoch) lose_array.append(avg_loss) 如果epoch%10 == 0: print(f“ epoch {epoch},損失:{avg_loss:.4f}”)
可視化訓練結果是了解我們神經網絡的性能的關鍵。在此步驟中,我們將繪製培訓損失而不是時期,以觀察模型的學習程度並確定培訓過程中的任何潛在問題。
#繪製結果 plt.plot(epoch_array,loss_array,label =“訓練損失”) plt.xlabel(“ Epoch”) plt.ylabel(“損失”) plt.title(“時代訓練損失”) plt.legend() plt.show()
這些示例演示了JAX如何將高性能與乾淨,可讀的代碼結合在一起。 JAX鼓勵的功能編程樣式使組成操作變得容易並應用轉換。
輸出:
陰謀:
這些示例演示了JAX如何將高性能與乾淨,可讀的代碼結合在一起。 JAX鼓勵的功能編程樣式使組成操作變得容易並應用轉換。
在建立神經網絡時,遵守最佳實踐可以顯著提高性能和可維護性。本節將討論各種策略和技巧,以優化您的代碼並提高基於JAX的模型的整體效率。
與JAX合作時,優化性能至關重要,因為它使我們能夠充分利用其功能。在這裡,我們將探索不同的技術來提高JAX功能的效率,以確保我們的模型在不犧牲可讀性的情況下盡快運行。
Just-On-time(JIT)彙編是JAX的出色功能之一,可以通過在運行時編譯功能來更快地執行。本節將概述有效使用JIT編譯的最佳實踐,從而幫助您避免常見的陷阱並最大程度地提高代碼的性能。
導入JAX 導入jax.numpy作為jnp 來自JAX Import Jit 來自JAX Import Lax #不好:動態的python控制流 @Jit def bad_function(x,n): 對於範圍(n)的i:#python循環 - 將展開 x = x 1 返回x 打印(“ ========================== #print(bad_function(1,1000))#不起作用
該函數使用標準的Python循環進行迭代n次,在每次迭代中將X的X遞增1。與JIT一起編譯時,JAX展開了循環,這可能是效率低下的,尤其是對於大型n。這種方法並不能完全利用JAX的功能進行性能。
#好:使用jax-native操作 @Jit def good_function(x,n): 返回xn#矢量化操作 打印(“ ========================== 打印(good_function(1,1000))
該函數執行相同的操作,但是它使用矢量化操作(XN)而不是循環。這種方法更有效,因為當以單個矢量化操作表示時,JAX可以更好地優化計算。
#更好:使用掃描進行循環 @Jit def best_function(x,n): def body_fun(i,val): 返回val 1 返回lax.fori_loop(0,n,body_fun,x) 打印(“ ========================== 打印(best_function(1,1000))
此方法使用`jax.lax.fori_loop`,這是一種有效實現循環的JAX本地方法。 `lax.fori_loop`執行與上一個函數相同的增量操作,但是它使用編譯的循環結構進行操作。 Body_fn函數定義了每次迭代的操作,並且`lax.fori_loop`從o到n執行它。該方法比展開循環更有效,並且特別適用於未知迭代次數的情況。
輸出:
=========================== =========================== 1001 =========================== 1001
該代碼展示了處理循環和控制流程中jax符合JIT符合功能的不同方法。
在任何計算框架中,有效的內存管理至關重要,尤其是在處理大型數據集或複雜模型時。本節將討論內存分配中的常見陷阱,並提供在JAX中優化內存使用情況的策略。
#不好:創建大型臨時陣列 @Jit def infelfficited_function(x): temp1 = jnp.power(x,2)#臨時數組 temp2 = jnp.sin(temp1)#另一個臨時 返回JNP.SUM(temp2)
defficited_function(x):此函數創建多個中間陣列Temp1,temp1,最後創建了temp2中元素的總和。創建這些臨時陣列可能是降低的,因為每個步驟都會分配內存並產生計算開銷,從而導致執行速度較慢和更高的內存使用情況。
#好:結合操作 @Jit def效率_function(x): 返回JNP.SUM(JNP.SIN(JNP.Power(x,2)))#單操作
此版本將所有操作結合到一行代碼中。它直接計算X平方元素的正弦,並總和結果。通過結合操作,它可以避免創建中間陣列,減少內存足跡並提高性能。
x = jnp.Array([1,2,3]) 打印(x) 打印(inffelided_function(x)) 打印(效率_function(x))
輸出:
[1 2 3] 0.49678695 0.49678695
有效的版本利用JAX優化計算圖的能力,通過最大程度地減少臨時數組創建來使代碼更快,更快。
調試是開發過程的重要組成部分,尤其是在復雜的數值計算中。在本節中,我們將討論特定於JAX的有效調試策略,使您能夠快速識別和解決問題。
該代碼顯示了在JAX內調試的技術,尤其是在使用JIT編譯功能時。
導入jax.numpy作為jnp 來自JAX Import Debug @Jit def debug_function(x): #使用debug.print而不是在jit中打印 debug.print(“ x的形狀:{}”,x.Shape) y = jnp.sum(x) debug.print(“ sum:{}”,y) 返回y
#要進行更複雜的調試,請突破JIT def debug_values(x): 打印(“輸入:”,x) 結果= debug_function(x) 打印(“輸出:”,結果) 返回結果
輸出:
打印(“ ========================== 打印(debug_function(jnp.array([1,2,3])))))) 打印(“ ========================== 打印(debug_values(jnp.array([1,2,3])))))))
這種方法允許使用標準的Python打印語句組合使用Debug.print()和更詳細的調試。
最後,我們將探索JAX中的常見模式和成語,可以幫助簡化您的編碼過程並提高效率。熟悉這些實踐將有助於開發更強大和表現的JAX應用程序。
#1。設備內存管理 def process_large_data(數據): #塊中的過程以管理內存 chunk_size = 100 結果= [] 對於i在範圍內(0,len(data),chunk_size): 塊=數據[i:i Chunk_size] chunk_result = jit(process_chunk)(塊) 結果。 返回JNP.Concatenate(結果) def Process_chunk(塊): chunk_temp = jnp.sqrt(塊) 返回chunk_temp
此功能會在塊中處理大型數據集,以避免壓倒性的設備內存。
它將Chunk_size設置為100,並在塊大小的數據增量上進行迭代,並分別處理每個塊。
對於每個塊,該函數使用JIT(Process_chunk)來彙編處理操作,從而通過提前進行編譯來改善性能。
每個塊的結果都使用JNP.Concatenated(結果)將單個列表串成一個單個數組。
輸出:
打印(“ ========================== data = jnp.Arange(10000) 打印(Data.Shape) 打印(“ ========================== 打印(數據) 打印(“ ========================== 打印(process_large_data(數據))
函數create_traing_state()演示了JAX中的隨機數生成器(RNG)的管理,這對於可重複性和一致的結果至關重要。
#2。處理隨機種子 def create_training_state(rng): #用於不同用途的拆分RNG rng,init_rng = jax.random.split(rng) params = init_network(init_rng) 返回參數,rng#返回新的RNG供下一個使用
它從初始RNG(RNG)開始,然後使用jax.random.split()將其分為兩個新的RNG。 Split RNGS執行不同的任務:`init_rng“初始化網絡參數,以及更新的RNG返回以進行後續操作。
該函數返回初始化的網絡參數和新的RNG供進一步使用,以確保在不同步驟中正確處理隨機狀態。
現在使用模擬數據測試代碼
DEF INIT_NETWORK(RNG): #初始化網絡參數 返回 { “ W1”:jax.random.normal(rng,(784,256)), “ b1”:jax.random.normal(rng,(256,)), “ W2”:jax.random.normal(rng,(256,10)), “ b2”:jax.random.normal(rng,(10,)), } 打印(“ ========================== key = jax.random.prngkey(0) 參數,rng = create_training_state(鍵) 打印(f“隨機數生成器:{rng}”) 打印(params.keys()) 打印(“ ========================== 打印(“ ========================== 打印(f“網絡參數形狀:{params ['w1']。形狀}”) 打印(“ ========================== 打印(f“網絡參數形狀:{params ['b1']。形狀}”) 打印(“ ========================== 打印(f“網絡參數形狀:{params ['w2']。形狀}”) 打印(“ ========================== 打印(f“網絡參數形狀:{params ['b2']。形狀}”) 打印(“ ========================== 打印(f“網絡參數:{params}”)
輸出:
def G(x,n): i = 0 當我<n i="1" g_jit_correct="jax.jit(g,static_argnames" n><p><strong>輸出:</strong></p> <pre class="brush:php;toolbar:false"> 30
如果JIT每次使用相同的參數編譯函數,則可以使用靜態參數。這對於JAX函數的性能優化很有用。
從函數引入部分導入 @partial(jax.jit,static_argnames = [“ n”]) def g_jit_decorated(x,n): i = 0 當我<n i="1"><p>如果您想將JIT中的靜態參數用作裝飾器,則可以在功能上使用JIT。 partial()函數。</p> <p><strong>輸出:</strong></p> <pre class="brush:php;toolbar:false"> 30
現在,我們已經學習並深入研究了許多令人興奮的概念和技巧以及整體編程風格。
本文中使用的所有代碼都在這裡
JAX是一種強大的工具,可為機器學習,深度學習和科學計算提供廣泛的功能。從基礎知識開始,進行實驗,並從Jax美麗的文檔和社區獲得幫助。有很多東西要學習,只要閱讀他人的代碼,就不會學到這一點。因此,立即開始在JAX中創建一個小型項目。關鍵是繼續前進,學習途中。
答:儘管JAX感覺就像Numpy,但它增加了自動分化,JIT彙編和GPU/TPU支持。
Q2。我需要GPU使用JAX嗎?答:在一個單詞中,儘管擁有GPU可以顯著加快較大數據的計算。
Q3。 JAX是Numpy的好替代品嗎?答:是的,您可以將JAX用作Numpy的替代方法,儘管如果您很好地使用JAX的功能,Jax的API看起來對Numpy Jax熟悉,則更強大。
Q4。我可以將現有的Numpy代碼與JAX一起使用嗎?答:大多數Numpy代碼可以以最小的更改適應JAX。通常只是將導入numpy作為NP將其導入JAX.numpy作為JNP。
Q5。 Jax比Numpy更難學習嗎?答:基礎知識和numpy一樣容易!告訴我一件事,閱讀上述文章和動手完成後,您會發現很難嗎?我為你回答。是的,很難。每個框架,語言,庫都很難,不是因為設計很難,而是因為我們沒有花太多時間來探索它。讓它有時間弄髒您的手,每天都會更容易。
本文所示的媒體不由Analytics Vidhya擁有,並由作者酌情使用。
以上是閃電般的JAX指南的詳細內容。更多資訊請關注PHP中文網其他相關文章!