from rich.console import Console
import os
from argparse import ArgumentParser

import grpc
from mpai_cae_arp.network import arp_pb2
from mpai_cae_arp.network import arp_pb2_grpc

# gRPC port of each AIM service; every client channel below targets the
# address "[::]" on the matching port.
_AIM_PORTS = {
    "AudioAnalyser": 50051,
    "VideoAnalyser": 50052,
    "TapeIrregularityClassifier": 50053,
    "TapeAudioRestoration": 50054,
    "Packager": 50055,
}

# One channel (without transport security) per AIM, keyed by AIM name.
# Created at import time and shared by run(), which closes them.
channels = {
    aim_name: grpc.insecure_channel(f"[::]:{port}")
    for aim_name, port in _AIM_PORTS.items()
}

def get_args(argv: "list[str] | None" = None) -> tuple[str, str]:
    """Parse the command-line options of the ARP client.

    Args:
        argv: Arguments to parse, excluding the program name. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so existing
            callers that pass nothing keep their behavior.

    Returns:
        A ``(files_name, host)`` tuple: the base name of the Preservation
        files (without extension) and the server host (``"localhost"`` unless
        ``-s/--host`` is given).

    Raises:
        SystemExit: raised by argparse when ``-f/--files-name`` is missing or
            the arguments are otherwise invalid.
    """
    parser = ArgumentParser(
        prog="mpai-cae-arp",
        description="The MPAI-CAE ARP client.",
    )
    parser.add_argument(
        "-f",
        "--files-name",
        help="Specify the name of the Preservation files (without extension)",
        required=True,
    )
    parser.add_argument(
        "-s",
        "--host",
        help="Specify the host of the server (default is localhost)",
        default="localhost",
    )

    args = parser.parse_args(argv)
    return args.files_name, args.host


# Exit codes: os.EX_* exist only on Unix, so fall back to the sysexits.h
# values elsewhere instead of failing with AttributeError on the error path.
_EX_USAGE = getattr(os, "EX_USAGE", 64)
_EX_SOFTWARE = getattr(os, "EX_SOFTWARE", 70)


def _run_stage(console: "Console", aim: "arp_pb2_grpc.AIMStub", request: "arp_pb2.JobRequest", description: str) -> None:
    """Stream one AIM ``work`` call under a spinner, echoing its messages.

    Every streamed result's ``message`` is printed. If a result has
    ``status == "error"`` its message is printed in red and the process exits
    with ``EX_SOFTWARE``. Closing the gRPC channels is left to the caller.
    """
    with console.status(description, spinner="bouncingBall"):
        for result in aim.work(request):
            if result.status == "error":
                console.print("[bold red]Error![/] :boom:")
                console.print(f"[italic red]{result.message}")
                # SystemExit instead of the site-provided exit(), which is
                # intended for the interactive shell only.
                raise SystemExit(_EX_SOFTWARE)
            console.print(result.message)


def run():
    """Run the complete ARP pipeline through the AIM gRPC services.

    Prints each AIM's title and version, then drives the stages in order:
    AudioAnalyser (IrregularityFile 1), VideoAnalyser, AudioAnalyser
    (IrregularityFile 2), TapeIrregularityClassifier, TapeAudioRestoration
    and Packager, all with working directory ``/data``.

    Exits with ``EX_USAGE`` when a host other than ``localhost`` is requested
    and with ``EX_SOFTWARE`` when an AIM reports an error. The channels in the
    module-level ``channels`` map are closed on every exit path; exceptions
    raised by the gRPC calls themselves propagate after the channels are
    closed.
    """
    console = Console()
    try:
        files_name, host = get_args()

        # NOTE(review): the channels are not built from `host`, so any other
        # host is rejected rather than silently ignored.
        if host != "localhost":
            console.print("[red bold]Warning![/] :warning:")
            console.print("[italic red]You are not using the default host. Make sure you are using the same host as the server.[/]")
            raise SystemExit(_EX_USAGE)

        audio_analyser = arp_pb2_grpc.AIMStub(channels["AudioAnalyser"])
        video_analyser = arp_pb2_grpc.AIMStub(channels["VideoAnalyser"])
        tape_irreg_classifier = arp_pb2_grpc.AIMStub(channels["TapeIrregularityClassifier"])
        tape_audio_restoration = arp_pb2_grpc.AIMStub(channels["TapeAudioRestoration"])
        packager = arp_pb2_grpc.AIMStub(channels["Packager"])

        info_request = arp_pb2.InfoRequest()
        for aim in (audio_analyser, video_analyser, tape_irreg_classifier, tape_audio_restoration, packager):
            response = aim.getInfo(info_request)
            console.print("[bold]{}[/], v{}".format(response.title, response.version))

        request = arp_pb2.JobRequest(
            working_dir="/data",
            files_name=files_name,
            index=1,
        )
        _run_stage(console, audio_analyser, request, "[bold]Computing AudioAnalyser IrregularityFile 1...")

        # The VideoAnalyser is sent the .mov file name, extension included.
        request.files_name = f"{files_name}.mov"
        _run_stage(console, video_analyser, request, "[bold]Computing VideoAnalyser IrregularityFiles...")

        # The remaining stages use index 2 and the extension-less base name.
        request.index = 2
        request.files_name = files_name
        _run_stage(console, audio_analyser, request, "[bold]Computing AudioAnalyser IrregularityFile 2...")
        _run_stage(console, tape_irreg_classifier, request, "[bold]Computing TapeIrregularityClassifier...")
        _run_stage(console, tape_audio_restoration, request, "[bold]Computing TapeAudioRestoration...")
        _run_stage(console, packager, request, "[bold]Packaging...")
    finally:
        # Release every channel on all exit paths: success, AIM error,
        # rejected host, argparse exit, or an exception from a gRPC call.
        for channel in channels.values():
            channel.close()

    console.print("[bold green]Success![/] :tada:")


if __name__ == "__main__":
    # Only start the client when executed as a script, not when imported.
    run()
