app.py 5.2 KB
Newer Older
Matteo's avatar
Matteo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from rich.console import Console
import os
from argparse import ArgumentParser

import grpc
from mpai_cae_arp.network import arp_pb2
from mpai_cae_arp.network import arp_pb2_grpc

# Port of each AIM server, keyed by the AIM name used throughout this client.
_AIM_PORTS = {
    "AudioAnalyser": 50051,
    "VideoAnalyser": 50052,
    "TapeIrregularityClassifier": 50053,
    "TapeAudioRestoration": 50054,
    "Packager": 50055,
}

# One insecure gRPC channel per AIM, all targeting "[::]:<port>". The channels
# are created once at import time and closed by run(); note that the host is
# fixed here and is not taken from the --host command-line option.
channels = {
    aim_name: grpc.insecure_channel(f"[::]:{port}")
    for aim_name, port in _AIM_PORTS.items()
}

def get_args(argv: "list[str] | None" = None) -> tuple[str, str]:
    """Parse the ARP client's command-line options.

    Args:
        argv: Arguments to parse, excluding the program name. When ``None``
            (the default), argparse reads them from ``sys.argv[1:]``, which
            keeps the original no-argument call working unchanged.

    Returns:
        A ``(files_name, host)`` tuple. ``files_name`` is the base name of the
        Preservation files (without extension); ``host`` defaults to
        ``"localhost"``.

    Raises:
        SystemExit: raised by argparse when ``-f/--files-name`` is missing or
            the arguments are otherwise invalid.
    """
    parser = ArgumentParser(
        prog="mpai-cae-arp",
        description="The MPAI-CAE ARP client.",
    )
    parser.add_argument("-f", "--files-name", help="Specify the name of the Preservation files (without extension)", required=True)
    parser.add_argument("-s", "--host", help="Specify the host of the server (default is localhost)", default="localhost")

    args = parser.parse_args(argv)

    return args.files_name, args.host


# Exit codes from sysexits.h. os.EX_* is only defined on Unix, so fall back to
# the standard numeric values elsewhere instead of failing with AttributeError.
_EX_USAGE = getattr(os, "EX_USAGE", 64)
_EX_SOFTWARE = getattr(os, "EX_SOFTWARE", 70)


def _close_channels() -> None:
    """Close every gRPC channel in the module-level ``channels`` registry."""
    for channel in channels.values():
        channel.close()


def _run_stage(console: Console, label: str, aim: arp_pb2_grpc.AIMStub, request: arp_pb2.JobRequest) -> None:
    """Stream one AIM's ``work`` call under a spinner, echoing each message.

    Args:
        console: Rich console used for the spinner and all output.
        label: Text shown (in bold) next to the spinner.
        aim: Stub of the AIM to invoke.
        request: Job request forwarded unchanged to ``aim.work``.

    Raises:
        SystemExit: with ``EX_SOFTWARE`` as soon as a streamed result has
            ``status == "error"``; the error message is printed first.
    """
    with console.status(f"[bold]{label}", spinner="bouncingBall"):
        for result in aim.work(request):
            if result.status == "error":
                console.print("[bold red]Error![/] :boom:")
                # markup=False: server text may contain "[...]" (paths, errno
                # prefixes) that Rich would otherwise parse as tags and either
                # mangle or reject with a MarkupError.
                console.print(result.message, style="italic red", markup=False)
                raise SystemExit(_EX_SOFTWARE)
            console.print(result.message, markup=False)


def run():
    """Run the whole ARP pipeline against the AIM servers in ``channels``.

    Prints each AIM's title/version, then runs, in order: AudioAnalyser
    (index 1), VideoAnalyser (on ``<files_name>.mov``), AudioAnalyser
    (index 2), TapeIrregularityClassifier, TapeAudioRestoration and Packager,
    all with ``working_dir="/data"``. All channels are closed on every exit
    path, including errors and exceptions raised by gRPC calls.

    Raises:
        SystemExit: ``EX_USAGE`` when a host other than ``localhost`` is
            requested (the channels target a fixed address), ``EX_SOFTWARE``
            when any AIM reports an error.
    """
    console = Console()
    files_name, host = get_args()

    try:
        if host != "localhost":
            console.print("[red bold]Warning![/] :warning:")
            console.print("[italic red]You are not using the default host. Make sure you are using the same host as the server.[/]")
            raise SystemExit(_EX_USAGE)

        audio_analyser = arp_pb2_grpc.AIMStub(channels["AudioAnalyser"])
        video_analyser = arp_pb2_grpc.AIMStub(channels["VideoAnalyser"])
        tape_irreg_classifier = arp_pb2_grpc.AIMStub(channels["TapeIrregularityClassifier"])
        tape_audio_restoration = arp_pb2_grpc.AIMStub(channels["TapeAudioRestoration"])
        packager = arp_pb2_grpc.AIMStub(channels["Packager"])

        info_request = arp_pb2.InfoRequest()
        for aim in (audio_analyser, video_analyser, tape_irreg_classifier, tape_audio_restoration, packager):
            response = aim.getInfo(info_request)
            console.print("[bold]{}[/], v{}".format(response.title, response.version))

        request = arp_pb2.JobRequest(
            working_dir="/data",
            files_name=files_name,
            index=1,
        )
        _run_stage(console, "Computing AudioAnalyser IrregularityFile 1...", audio_analyser, request)

        # The VideoAnalyser is given the file name with a ".mov" extension;
        # every other stage receives the bare name.
        request.files_name = f"{files_name}.mov"
        _run_stage(console, "Computing VideoAnalyser IrregularityFiles...", video_analyser, request)

        request.index = 2
        request.files_name = files_name
        _run_stage(console, "Computing AudioAnalyser IrregularityFile 2...", audio_analyser, request)
        _run_stage(console, "Computing TapeIrregularityClassifier...", tape_irreg_classifier, request)
        _run_stage(console, "Computing TapeAudioRestoration...", tape_audio_restoration, request)
        _run_stage(console, "Packaging...", packager, request)
    finally:
        # Runs on success, on SystemExit, and on gRPC exceptions alike.
        _close_channels()

    console.print("[bold green]Success![/] :tada:")


if __name__ == "__main__":
    # Script entry point: execute the client pipeline.
    run()