plot.py 2.04 KB
Newer Older
valentini's avatar
valentini committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os

import glob
import pandas
import pandas as pd
import matplotlib.pyplot as plt
import argparse

def plot_psnrs(batch: int, lr: float, epochs: int, plot_dir: str, metrics: pd.DataFrame) -> None:
    """ Plot the losses and the psnr
        Input:
            - batch
            - lr
            - epochs
            aqaaa
            - plot_dir 
            - metrics
    """
    plt.figure(1)
    plt.plot(metrics["mse_train"].values, label = "Train loss")
    plt.plot(metrics["mse_val"].values, label = "Valid loss")
    plt.title(
        f"Batch Size: {batch} | Learning Rate: {lr} | Epochs: {epochs} ")
    plt.legend()
    plt.savefig(os.path.join(plot_dir, "fig_losses.png"))

    plt.figure(2)
    plt.plot(metrics["psnr_train"].values, label = "Train PSNR")
    plt.plot(metrics["psnr_val"].values, label = "Valid PSNR")
    plt.title(
        f"Batch Size: {batch} | Learning Rate: {lr} | Epochs: {epochs} ")
    plt.legend()
    plt.savefig(os.path.join(plot_dir, "psnr.png"))
    plt.close("all")

    return None

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Plotter for csv graphs',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('csv_file', type=str, help='Path to the folder where the CSV file is located')
    parser.add_argument('-b', '--batch', type=int, default=2, help='Batch size of Reference') # used only for legend
    parser.add_argument('-e', '--epoch', type=int, default=50, help='Number of  the training epoch ') # used only for legend
    parser.add_argument('--lr', type=float, default=10e-5, help='Learning rate of Reference') # used only for legend
 
    args = parser.parse_args()

    csv_file_paths = glob.glob(os.path.normpath(args.csv_file) + "*/**/log.csv")
    for csv_file_path in csv_file_paths:
        print("Processing: " + csv_file_path)
        out_dir = os.path.dirname(csv_file_path)

        data = pandas.read_csv(csv_file_path)

        plot_psnrs(args.batch, args.lr, args.epoch, out_dir, metrics=data)

        data = None