import torch
PyTorch torch.lerp Exploration

Introduction

torch.lerp stands for linear interpolation: it is a handy function that blends two tensors using a provided weight. Let’s explore how it can be used!

Let’s start off with something easy. We have two items and we want to combine them by taking 75% of item1 and 25% of item2. Mathematically, this could be represented as \(output = 0.75*item1+0.25*item2\). A more general form of this can be represented as \(output = pct*item1+(1-pct)*item2\). This is a very common piece of code in machine learning papers, which is why PyTorch has the handy torch.lerp function!
item1 = torch.tensor(2.)
item2 = torch.tensor(6.)
weight = 1/4 # This means that we will use 3/4 of item1 and 1/4 of item2
output1 = (1-weight)*item1 + weight*item2
output2 = torch.lerp(item1, item2, weight)
output1
tensor(3.)
output2
tensor(3.)
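One detail worth knowing: the weight does not have to be a Python scalar. torch.lerp also accepts a tensor of weights and interpolates elementwise. Here is a small sketch (the values are purely illustrative) checking it against the manual formula:

```python
import torch

start = torch.tensor([0., 10., 100.])
end = torch.tensor([1., 20., 200.])
weight = torch.tensor([0.25, 0.5, 0.75])  # one weight per element

# manual elementwise interpolation: (1-w)*start + w*end
manual = (1 - weight) * start + weight * end
out = torch.lerp(start, end, weight)

assert torch.allclose(manual, out)
```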
Here is an example of lerp being used in practice, taken from the mixup paper:
import matplotlib.pyplot as plt
np_april = plt.imread('notebook_images/pets/april.jpg')
april = torch.from_numpy(np_april)
april_smaller = april[600:600+1224, 1100:1100+1124, :]/255.
x_i = april_smaller # simulated image #1
x_j = torch.rand_like(x_i) # simulated image #2
lam = 0.5 # Let's set lam to 0.5 which will blend equal parts of x_i and x_j.
Now, let’s blend these two ‘images’
fig, axs = plt.subplots(ncols=5, sharey=True, figsize=(18,3))
for i, lam in enumerate([0, 0.25, 0.5, 0.75, 1]):
    x_hat = torch.lerp(x_j, x_i, lam)
    axs[i].set_title(f'{i}:λ={lam}')
    axs[i].imshow(x_hat)
from fastcore.all import test_close, test_eq
x_hat = torch.lerp(x_j, x_i, weight=0.5)
test_close((x_j + x_i)/2, x_hat, eps=1e-6)
As we expected, these two are equal (within a small amount of error due to floating-point math).
Linear interpolation is also often used for exponentially weighted averaging, which lets us avoid entirely discarding previous results while still only storing the most recent running value.
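As a quick sketch of that idea (the values here are just illustrative), the in-place variant Tensor.lerp_ makes a running exponentially weighted average a one-liner per step:

```python
import torch

values = [torch.tensor(v) for v in [1., 2., 3., 4.]]
ema = values[0].clone()
for v in values[1:]:
    # keep 90% of the running average, take 10% of the newest value:
    # ema = 0.9 * ema + 0.1 * v
    ema.lerp_(v, 0.1)
```

Each older observation's influence shrinks by a factor of 0.9 per step, which is where the "exponential" in the name comes from.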
Here is what exponentially weighted averaging looks like in the Adam optimizer update rule:
This algorithm actually contains two linear interpolations:
\(m_t = \beta_1*m_{t-1}+(1-\beta_1)*g_t\)
\(v_t = \beta_2*v_{t-1}+(1-\beta_2)*g_t^2\)
and here is what they look like in code:
m_tm1 = torch.tensor(0.)
v_tm1 = torch.tensor(0.)
g_t = torch.tensor(0.5)
beta1 = torch.tensor(0.99)
beta2 = torch.tensor(0.999)

# Note: the lerp weight is (1 - beta) so that beta multiplies the
# previous moment, matching the formulas above.
m_t = torch.lerp(m_tm1, g_t, 1 - beta1)
v_t = torch.lerp(v_tm1, g_t**2, 1 - beta2)
m_t
tensor(0.0050)
v_t
tensor(0.0002)
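To tie the code back to the formulas, here is a small self-contained check (with illustrative values; beta1=0.9 and beta2=0.999 are the Adam paper's defaults) that torch.lerp reproduces the two moment updates. Note that the lerp weight is \(1-\beta\), so that \(\beta\) multiplies the previous moment:

```python
import torch

m_prev = torch.tensor(0.3)   # m_{t-1}
v_prev = torch.tensor(0.2)   # v_{t-1}
g = torch.tensor(0.5)        # g_t
beta1, beta2 = 0.9, 0.999

m = torch.lerp(m_prev, g, 1 - beta1)     # beta1*m_prev + (1-beta1)*g
v = torch.lerp(v_prev, g**2, 1 - beta2)  # beta2*v_prev + (1-beta2)*g^2

assert torch.allclose(m, beta1 * m_prev + (1 - beta1) * g)
assert torch.allclose(v, beta2 * v_prev + (1 - beta2) * g**2)
```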
Hope this was helpful and gave you a better understanding of what torch.lerp is and where it is used. If you have any suggestions or questions, please feel free to reach out and I would be happy to address them!