PyTorch has been upgraded to 1.7, and FFT (Fast Fourier Transform) is now available through the torch.fft module. In this article, we will use torch.fft to apply a high pass filter to an image.
Image to use
Contents
- Import modules
- Read image
- FFT!
- Make fftshift
- Apply high pass filter
- Inverse FFT !
- Check the results.
Implementation
1. Import modules
import torch.fft
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
2. Read image
img = Image.open("lena.jpg")
img = img.convert('L')  # convert to 8-bit grayscale
img = np.array(img)
img = torch.from_numpy(img)
print(img.shape) # torch.Size([512, 512])
3. FFT!
fft_img = torch.fft.fft(img)
print(fft_img.shape) # torch.Size([512, 512])
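It’s very easy. An FFT is not hard to implement once you understand the algorithm, but with torch.fft you do not even need to. Note that torch.fft.fft transforms only the last dimension, so each row of the image is transformed independently (torch.fft.fftn transforms all dimensions at once). The result is complex-valued and shows how much of each frequency component there is. The first few values of the first row look like this: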
print(fft_img[0][:5])
# tensor([ 55542.0000+0.0000j,
# 1218.6273+3518.8711j,
# -847.8992+99.0569j,
# 1609.7048-4533.7368j,
# 2331.2732-2991.2090j])
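To make "how much of each frequency component" concrete, here is a minimal sketch on a synthetic signal rather than on the image. The signal length `n` and the frequency `k` are arbitrary values chosen for illustration.

```python
import math
import torch
import torch.fft

n, k = 64, 5  # illustrative signal length and frequency (assumptions)
t = torch.arange(n, dtype=torch.float32)
signal = torch.cos(2 * math.pi * k * t / n)  # pure cosine with k cycles

spectrum = torch.fft.fft(signal)  # 1D FFT along the last dimension
magnitude = spectrum.abs()

# A real cosine puts its energy in bins k and n - k, each with magnitude of about n / 2.
peak_bins = sorted(torch.topk(magnitude, 2).indices.tolist())
print(peak_bins)  # [5, 59]
```

All other bins are close to zero, which is why the magnitude of each coefficient can be read as the strength of that frequency.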
4. Make fftshift
In order to apply a high pass filter, I want to reorder fft_img so that the low-frequency components are at the center of the image. PyTorch 1.7 has no counterpart to NumPy's np.fft.fftshift (newer PyTorch releases provide torch.fft.fftshift), so I will write one myself.
def roll_n(X, axis, n):
    # Move the first n entries along `axis` to the end of that axis.
    f_idx = tuple(slice(None, None, None)
                  if i != axis else slice(0, n, None)
                  for i in range(X.dim()))
    b_idx = tuple(slice(None, None, None)
                  if i != axis else slice(n, None, None)
                  for i in range(X.dim()))
    front = X[f_idx]
    back = X[b_idx]
    return torch.cat([back, front], axis)

def fftshift(X):
    # torch.fft returns complex tensors, so the real and imaginary parts
    # do not need to be split. Like np.fft.fftshift, shift every axis so
    # that index 0 moves to the center.
    for dim in range(X.dim()):
        X = roll_n(X, axis=dim, n=int(np.ceil(X.size(dim) / 2)))
    return X

def ifftshift(X):
    # Undo fftshift: move the center of every axis back to index 0.
    for dim in range(X.dim() - 1, -1, -1):
        X = roll_n(X, axis=dim, n=int(np.floor(X.size(dim) / 2)))
    return X
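One way to sanity-check a hand-written shift is to compare it with np.fft.fftshift on a small array. The sketch below uses torch.roll as a compact stand-in; `shift_all_axes` and `unshift_all_axes` are hypothetical helper names, not the functions defined above.

```python
import numpy as np
import torch

def shift_all_axes(x):
    # Roll every axis by floor(size / 2), which is what np.fft.fftshift does by default.
    dims = tuple(range(x.dim()))
    shifts = tuple(s // 2 for s in x.shape)
    return torch.roll(x, shifts=shifts, dims=dims)

def unshift_all_axes(x):
    # Inverse: roll every axis by -floor(size / 2), matching np.fft.ifftshift.
    dims = tuple(range(x.dim()))
    shifts = tuple(-(s // 2) for s in x.shape)
    return torch.roll(x, shifts=shifts, dims=dims)

x = torch.arange(20, dtype=torch.float32).reshape(4, 5)  # odd width exercises the rounding
shifted = shift_all_axes(x)

matches_numpy = np.array_equal(shifted.numpy(), np.fft.fftshift(x.numpy()))
round_trip_ok = torch.equal(unshift_all_axes(shifted), x)
print(matches_numpy, round_trip_ok)  # True True
```

The odd-sized axis matters: for even sizes the forward and inverse shifts are identical, so a rounding mistake only shows up when a dimension is odd.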
5. Apply high pass filter
Decide the filter size. In this case, let’s zero out the central 95% of the shifted spectrum along each axis, which is where the low frequencies sit.
fft_shift_img = fftshift(fft_img)
filter_rate = 0.95
h, w = fft_shift_img.shape[:2] # height and width
cy, cx = int(h/2), int(w/2) # center of the spectrum
rh, rw = int(filter_rate * cy), int(filter_rate * cx) # half-size of the filter
# set the region around the center to zero
fft_shift_img[cy-rh:cy+rh, cx-rw:cx+rw] = 0
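The index arithmetic is easier to follow with concrete numbers. This small sketch repeats the calculation for a 512 x 512 array and counts how many entries the slice sets to zero; the all-ones tensor is only a stand-in, not real FFT output.

```python
import torch

h, w = 512, 512
filter_rate = 0.95
cy, cx = h // 2, w // 2                                # 256, 256
rh, rw = int(filter_rate * cy), int(filter_rate * cx)  # 243, 243

spectrum = torch.ones(h, w)                         # stand-in for a shifted spectrum
spectrum[cy - rh:cy + rh, cx - rw:cx + rw] = 0      # rows and columns 13..498

zeroed = int((spectrum == 0).sum())
kept_fraction = 1 - zeroed / (h * w)
print(zeroed, round(kept_fraction, 3))  # 236196 0.099
```

So a filter_rate of 0.95 applies per axis: roughly 90% of the coefficients are zeroed and about 10% survive.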
6. Inverse FFT !
# restore the original frequency ordering
ifft_shift_img = ifftshift(fft_shift_img)
# inverse fft
ifft_img = torch.fft.ifft(ifft_shift_img)
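As a quick check of the inverse step, the sketch below confirms that torch.fft.ifft undoes torch.fft.fft when nothing is filtered in between. The random 8 x 8 tensor is a stand-in for an image, chosen only for illustration.

```python
import torch
import torch.fft

torch.manual_seed(0)
x = torch.rand(8, 8)  # random stand-in for a grayscale image

restored = torch.fft.ifft(torch.fft.fft(x))  # both act on the last dimension

# The result is complex; for real input the imaginary part is only rounding noise.
max_imag = restored.imag.abs().max().item()
max_err = (restored.real - x).abs().max().item()
print(max_imag < 1e-5, max_err < 1e-5)  # True True
```

Because the inverse transform always returns a complex tensor, only the real part is useful when the data started out as real-valued pixels.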
7. Check the results.
ifft_img = ifft_img.to('cpu').detach().numpy().copy()
ifft_img = ifft_img.real.astype(int)  # np.int was removed in recent NumPy; use int
plt.imshow(ifft_img, cmap="gray")
plt.show()
I was able to confirm that the filter was properly applied :)
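For comparison, here is a minimal sketch of the same idea with a two-dimensional transform, run on a synthetic image so that it does not depend on lena.jpg. It swaps torch.fft.fft/ifft for torch.fft.fftn/ifftn and the hand-written shift for torch.roll, and it assumes a PyTorch version in which torch.roll accepts complex tensors. The function name `high_pass_2d` and the test image are made up for illustration; this is not the code used above.

```python
import torch
import torch.fft

def high_pass_2d(img, filter_rate=0.95):
    """Hypothetical helper: zero the low frequencies of a 2D real tensor."""
    h, w = img.shape
    cy, cx = h // 2, w // 2
    spectrum = torch.fft.fftn(img.float(), dim=(0, 1))               # 2D FFT
    spectrum = torch.roll(spectrum, shifts=(cy, cx), dims=(0, 1))    # move DC to the center
    rh, rw = int(filter_rate * cy), int(filter_rate * cx)
    spectrum[cy - rh:cy + rh, cx - rw:cx + rw] = 0                   # remove low frequencies
    spectrum = torch.roll(spectrum, shifts=(-cy, -cx), dims=(0, 1))  # undo the shift
    return torch.fft.ifftn(spectrum, dim=(0, 1)).real

# Synthetic 64 x 64 image: constant brightness 100 plus a +/-50 checkerboard.
idx = torch.arange(64)
sign = 1.0 - 2.0 * ((idx[:, None] + idx[None, :]) % 2).float()  # +1 / -1 pattern
img = 100.0 + 50.0 * sign

out = high_pass_2d(img)
# The constant (zero-frequency) part is removed; the checkerboard, which is the
# highest representable frequency, passes through unchanged.
print(torch.allclose(out, 50.0 * sign, atol=1e-2))
```

The checkerboard is a convenient test input because its only non-zero coefficients are the zero frequency and the highest frequency, so it is easy to see which one the filter removes.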