File size: 2,046 Bytes
d65c9b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from PIL import Image
from torchvision import transforms

from src.lpips import LPIPS
import torch.nn as nn

# Device on which the shared LPIPS network (``loss_fn``) is placed at import time.
dev = 'cuda'
# Converts a PIL image to a float tensor of shape (C, H, W); for 8-bit images
# torchvision's ToTensor also rescales pixel values from [0, 255] to [0, 1].
to_tensor_transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)
# Mean of squared element-wise differences ('mean' is nn.MSELoss's default reduction).
mse_loss = nn.MSELoss(reduction='mean')

def calculate_l2_difference(image1, image2, device='cuda'):
    """Return the mean squared error between two images as a Python float.

    Despite the name, this is the *mean* of squared differences computed by
    ``mse_loss``, not the Euclidean (L2) norm of the difference.

    Args:
        image1, image2: PIL images or tensors. PIL images are converted with
            ``to_tensor_transform`` and moved to ``device``; anything else is
            passed through unchanged.
        device: Target device for converted PIL images.
    """
    converted = []
    for img in (image1, image2):
        if isinstance(img, Image.Image):
            img = to_tensor_transform(img).to(device)
        converted.append(img)
    first, second = converted
    return mse_loss(first, second).item()

def calculate_psnr(image1, image2, device='cuda', *, max_value=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    PSNR = 10 * log10(max_value**2 / MSE). Identical images give MSE == 0,
    for which this returns ``float('inf')``.

    Args:
        image1, image2: PIL images or tensors. PIL images are converted with
            ``to_tensor_transform`` (values in [0, 1] for 8-bit images) and
            moved to ``device``; tensors are used as given.
        device: Target device for converted PIL images.
        max_value: Largest possible pixel value of the inputs. The default
            1.0 matches ``to_tensor_transform`` output; pass e.g. 255.0 for
            tensors that hold raw 8-bit pixel values.

    Returns:
        PSNR as a Python float.
    """
    if isinstance(image1, Image.Image):
        image1 = to_tensor_transform(image1).to(device)
    if isinstance(image2, Image.Image):
        image2 = to_tensor_transform(image2).to(device)

    # Only a scalar float is returned, so no autograd graph is needed.
    with torch.no_grad():
        mse = mse_loss(image1, image2)
        psnr = 10 * torch.log10(max_value ** 2 / mse).item()
    return psnr


# Shared LPIPS network with a VGG backbone, built once at import time, moved to
# ``dev`` and switched to eval mode. Since ``dev`` is 'cuda', importing this
# module requires CUDA to be available.
loss_fn = LPIPS(net_type='vgg')
loss_fn = loss_fn.to(dev)
loss_fn = loss_fn.eval()

def calculate_lpips(image1, image2, device='cuda'):
    """Return the LPIPS perceptual distance between two images as a float.

    Uses the module-level ``loss_fn`` (VGG-backed LPIPS network).

    Args:
        image1, image2: PIL images or tensors. PIL images are converted with
            ``to_tensor_transform`` and moved to ``device``; tensors are
            passed through unchanged.
        device: Target device for converted PIL images. ``loss_fn`` itself is
            placed on the module-level ``dev``, so the inputs must end up on
            that same device.

    NOTE(review): ``to_tensor_transform`` yields a (C, H, W) tensor with values
    in [0, 1] and no batch dimension. Reference LPIPS implementations expect
    (N, 3, H, W) inputs scaled to [-1, 1]. Confirm that ``src.lpips.LPIPS``
    handles this input as-is; otherwise the scores will be off.
    """
    if isinstance(image1, Image.Image):
        image1 = to_tensor_transform(image1).to(device)
    if isinstance(image2, Image.Image):
        image2 = to_tensor_transform(image2).to(device)

    # The score is returned as a plain float, so skip building an autograd
    # graph through the LPIPS backbone (saves memory when inputs require grad).
    with torch.no_grad():
        loss = loss_fn(image1, image2).item()
    return loss

def calculate_metrics(image1, image2, device='cuda', size=(512, 512)):
    """Compute L2 (mean squared error), PSNR and LPIPS between two images.

    PIL inputs are first resized to ``size`` (PIL's ``(width, height)``) and
    then converted to tensors on ``device``. Tensor inputs are used as given
    and are not resized.

    Returns:
        dict with the metric values under the keys "l2", "psnr" and "lpips".
    """
    def _prepare(img):
        # Only PIL images are resized and converted; tensors pass through.
        if not isinstance(img, Image.Image):
            return img
        return to_tensor_transform(img.resize(size)).to(device)

    tensor1 = _prepare(image1)
    tensor2 = _prepare(image2)
    return {
        "l2": calculate_l2_difference(tensor1, tensor2, device),
        "psnr": calculate_psnr(tensor1, tensor2, device),
        "lpips": calculate_lpips(tensor1, tensor2, device),
    }

def get_empty_metrics():
    """Return a new metrics dict with every metric initialised to 0.

    The keys match those produced by ``calculate_metrics``. A fresh dict is
    built on every call, so callers may mutate the result freely.
    """
    return dict.fromkeys(("l2", "psnr", "lpips"), 0)

def print_results(results):
    """Print the L2, PSNR and LPIPS values of a metrics dict on one line.

    ``results`` must provide the keys 'l2', 'psnr' and 'lpips' (as returned by
    ``calculate_metrics`` or ``get_empty_metrics``); a missing key raises
    ``KeyError``.
    """
    l2_value = results['l2']
    psnr_value = results['psnr']
    lpips_value = results['lpips']
    print(f"Reconstruction Metrics: L2: {l2_value},\t PSNR: {psnr_value},\t LPIPS: {lpips_value}")