How to use torch.fft to apply a high pass filter to an image.

Kai
Dec 17, 2020

PyTorch 1.7 added the torch.fft module, so FFT (Fast Fourier Transform) functions are now available in PyTorch. In this article, we will use torch.fft to apply a high pass filter to an image.

Image to use

Of course! (lena.jpg)

Contents

  1. Import modules
  2. Read image
  3. FFT!
  4. Make fftshift
  5. Apply high pass filter
  6. Inverse FFT !
  7. Check the results.

Implementation

1. Import modules

import torch.fft
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

2. Read image

img = Image.open("lena.jpg")
img = img.convert('L') # convert to single-channel grayscale
img = np.array(img)
img = torch.from_numpy(img)
print(img.shape) # torch.Size([512, 512])

3. FFT!

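fft_img = torch.fft.fftn(img) # 2-D FFT over both image axes; invert with torch.fft.ifftn
print(fft_img.shape) # torch.Size([512, 512])

It’s very easy. Implementing an FFT yourself is not hard once you understand the algorithm, but with torch.fft you don’t even need to. Note that torch.fft.fft transforms only the last dimension, so for an image we use torch.fft.fftn, which transforms both axes. Each complex value says how much of one frequency component is present. For example, the 1-D spectrum of the first image row begins like this:

print(torch.fft.fft(img[0])[:5]) # 1-D FFT of the first row
# tensor([ 55542.0000+0.0000j,
# 1218.6273+3518.8711j,
# -847.8992+99.0569j,
# 1609.7048-4533.7368j,
# 2331.2732-2991.2090j])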
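To make "how much of each frequency component" concrete, here is a minimal sketch on a synthetic signal rather than the image. The sample count `n` and the frequency `k` are values I picked for illustration; they are not from the article's data.

```python
import math
import torch

n = 64   # number of samples (illustrative value)
k = 5    # cycles across the window (illustrative value)

t = torch.arange(n, dtype=torch.float32)
signal = torch.cos(2 * math.pi * k * t / n)

spectrum = torch.fft.fft(signal)   # complex tensor of length n
magnitude = spectrum.abs()

# A real cosine with k cycles puts its energy in bins k and n - k,
# each with magnitude n / 2; every other bin is (numerically) zero.
peak_bins = sorted(torch.topk(magnitude, 2).indices.tolist())
print(peak_bins)  # → [5, 59]
```

Bin n - k is the negative-frequency twin of bin k, which is why low frequencies sit at both ends of an unshifted spectrum.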

4. Make fftshift

In order to apply a high pass filter, I want to reorder fft_img so that the low-frequency components are at the center of the image. PyTorch 1.7 has no counterpart to np.fft.fftshift (newer PyTorch releases include torch.fft.fftshift), so I will write one myself.

def roll_n(X, axis, n):
    # Circularly shift X along `axis`: the first n entries move to the end.
    f_idx = tuple(slice(None, None, None)
                  if i != axis else slice(0, n, None)
                  for i in range(X.dim()))
    b_idx = tuple(slice(None, None, None)
                  if i != axis else slice(n, None, None)
                  for i in range(X.dim()))
    front = X[f_idx]
    back = X[b_idx]
    return torch.cat([back, front], axis)

def fftshift(X):
    # torch.fft returns a complex tensor, so no real/imag split is needed.
    # Move the zero-frequency term to the center of every axis.
    for dim in range(X.dim()):
        X = roll_n(X, axis=dim, n=int(np.ceil(X.size(dim) / 2)))
    return X

def ifftshift(X):
    # Undo fftshift, visiting the axes in reverse order.
    for dim in range(X.dim() - 1, -1, -1):
        X = roll_n(X, axis=dim, n=int(np.floor(X.size(dim) / 2)))
    return X
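As a cross-check, the same reordering can be sketched with torch.roll and compared against NumPy. The helper names `fftshift_roll` and `ifftshift_roll` are hypothetical, and this is only one possible way to write the shift, not the version used in this article.

```python
import numpy as np
import torch

def fftshift_roll(x):
    # Hypothetical helper: roll every axis right by half its length,
    # which is the shift np.fft.fftshift applies.
    shifts = tuple(s // 2 for s in x.shape)
    return torch.roll(x, shifts=shifts, dims=tuple(range(x.dim())))

def ifftshift_roll(x):
    # Hypothetical helper: the inverse roll.
    shifts = tuple(-(s // 2) for s in x.shape)
    return torch.roll(x, shifts=shifts, dims=tuple(range(x.dim())))

# Use one odd and one even axis length, since they behave differently.
x = torch.arange(5 * 6, dtype=torch.float32).reshape(5, 6)
spec = torch.fft.fftn(x)

shifted = fftshift_roll(spec)
matches_numpy = np.allclose(shifted.numpy(), np.fft.fftshift(spec.numpy()))
round_trip = bool((ifftshift_roll(shifted) == spec).all())
print(matches_numpy, round_trip)
```

For even sizes such as 512 the forward and inverse shifts are the same roll; the distinction only matters for odd sizes.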

5. Apply high pass filter

Decide the filter size. In this case, let’s zero out the central 95% of the shifted spectrum along each axis, which removes the low-frequency components.

fft_shift_img = fftshift(fft_img)
filter_rate = 0.95
h, w = fft_shift_img.shape[:2] # height and width
cy, cx = h // 2, w // 2 # center of the shifted spectrum
rh, rw = int(filter_rate * cy), int(filter_rate * cx) # half-size of the filter
# zero the low-frequency block around the center
fft_shift_img[cy-rh:cy+rh, cx-rw:cx+rw] = 0
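To see what this kind of mask does, here is a small sketch on a synthetic image where the answer is known in advance. The 64x64 size, the checkerboard pattern, and the 0.5 filter rate are my own assumptions for illustration, and the shift is done with torch.roll instead of the helpers above.

```python
import torch

# Synthetic image: a constant brightness (lowest frequency)
# plus a fine checkerboard (highest frequency on both axes).
h = w = 64
rows = torch.arange(h).reshape(-1, 1)
cols = torch.arange(w).reshape(1, -1)
checker = ((rows + cols) % 2).float()
img = 100.0 + 10.0 * checker

spec = torch.fft.fftn(img)
shifts = (h // 2, w // 2)
spec_shifted = torch.roll(spec, shifts=shifts, dims=(0, 1))  # low frequencies to center

filter_rate = 0.5  # illustrative value
cy, cx = h // 2, w // 2
rh, rw = int(filter_rate * cy), int(filter_rate * cx)
spec_shifted[cy - rh:cy + rh, cx - rw:cx + rw] = 0           # drop low frequencies

spec_filtered = torch.roll(spec_shifted, shifts=(-shifts[0], -shifts[1]), dims=(0, 1))
filtered = torch.fft.ifftn(spec_filtered).real

# The constant offset is gone, the checkerboard detail survives.
print(img.mean().item(), filtered.mean().item())
```

Because the mask removes the zero-frequency term, the filtered image has a mean close to zero and contains negative values, which is worth keeping in mind when displaying it.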

6. Inverse FFT !

# undo the shift so the spectrum is back in FFT order
ifft_shift_img = ifftshift(fft_shift_img)
# inverse 2-D FFT (the counterpart of a forward torch.fft.fftn)
ifft_img = torch.fft.ifftn(ifft_shift_img)
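The inverse transform returns a complex tensor even when the original image was real. As a rough sanity check, assuming a small random tensor stands in for the image, a forward and inverse 2-D FFT with no filtering should reproduce the input with an imaginary part near zero:

```python
import torch

torch.manual_seed(0)
x = torch.rand(8, 8)  # stand-in for a real-valued image

recon = torch.fft.ifftn(torch.fft.fftn(x))  # complex result

max_imag = recon.imag.abs().max().item()
max_err = (recon.real - x).abs().max().item()
print(max_imag, max_err)  # both should be tiny (float32 rounding only)
```

This is why taking only the real part afterwards loses nothing in the unfiltered case. After filtering, a mask that is not perfectly symmetric around the center can leave a small imaginary residue.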

7. Check the results.

ifft_img = ifft_img.to('cpu').detach().numpy().copy()
# keep the real part; np.int was removed from recent NumPy releases, so use int
ifft_img = ifft_img.real.astype(int)
plt.imshow(ifft_img, cmap="gray")
plt.show()

I was able to confirm that the filter was properly applied :)
