# server.py - gRPC server for the MPAI CAE-ARP Audio Analyser AIM.
import logging
import os
from concurrent import futures
from typing import Any, Callable

import grpc
from grpc import StatusCode
from rich.console import Console

from mpai_cae_arp.files import File, FileType
from mpai_cae_arp.types.irregularity import IrregularityFile, Source
from mpai_cae_arp.network import arp_pb2_grpc as arp_pb2_grpc
from mpai_cae_arp.network.arp_pb2 import (
    JobRequest,
    JobResponse,
    Contact,
    InfoResponse,
    License,
)

import segment_finder as sf
import classifier as cl

# Static AIM metadata served by AudioAnalyserServicer.getInfo (it reads the
# 'title', 'description', 'version', 'contact' and 'license_info' keys).
# The path is relative, so it presumably resolves against the process working
# directory - launch the server from the directory that contains config/.
info = File(
    'config/server.yaml',
    FileType.YAML,
).get_content()


def try_or_error_response(
    context,
    on_success_message: str,
    on_error_message: str,
    func: Callable,
    args,
    on_success_status: StatusCode = StatusCode.OK,
    on_error_status: StatusCode = StatusCode.INTERNAL,
) -> tuple[JobResponse, Any]:
    """Run ``func(*args)`` and report the outcome on the gRPC ``context``.

    Args:
        context: gRPC servicer context; its code and details are set to the
            success or error status/message.
        on_success_message: details text and response message on success.
        on_error_message: details text and response message on failure.
        func: callable to execute.
        args: positional arguments for ``func``. A tuple is unpacked; any other
            value (e.g. a single path string or a list) is passed as the one
            and only argument instead of being split element by element.
        on_success_status: status code set when ``func`` returns normally.
        on_error_status: status code set when ``func`` raises.

    Returns:
        ``(JobResponse, result)``: the response has status ``"success"`` and
        ``result`` is ``func``'s return value, or the response has status
        ``"error"`` and ``result`` is ``None``.
    """
    if not isinstance(args, tuple):
        # Guard against ``args=some_string`` being star-unpacked per character.
        args = (args,)
    try:
        result = func(*args)
    except Exception:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt,
        # SystemExit and GeneratorExit still propagate. The traceback is logged
        # because the client only receives the generic message.
        logging.getLogger(__name__).exception(on_error_message)
        context.set_code(on_error_status)
        context.set_details(on_error_message)
        return JobResponse(status="error", message=on_error_message), None
    context.set_code(on_success_status)
    context.set_details(on_success_message)
    return JobResponse(status="success", message=on_success_message), result


def error_response(context, status, message):
    """Flag the current RPC as failed and build the matching error ``JobResponse``.

    Args:
        context: gRPC servicer context; receives ``status`` as its code and
            ``message`` as its details.
        status: status code to report.
        message: description, echoed in the response body.
    """
    context.set_code(status)
    context.set_details(message)
    error_fields = {"status": "error", "message": message}
    return JobResponse(**error_fields)


class AudioAnalyserServicer(arp_pb2_grpc.AIMServicer):
    """gRPC servicer for the Audio Analyser AIM.

    ``getInfo`` answers with the metadata held in the module-level ``info``
    mapping; ``work`` runs one analysis pass per request and streams a
    ``JobResponse`` after each step.
    """

    def __init__(self, console: Console):
        # Console used to log incoming requests and failures.
        self.console = console

    def getInfo(self, request, context) -> InfoResponse:
        """Return the AIM title, description, version, contact and license."""
        self.console.log('Received request for AIM info')

        context.set_code(StatusCode.OK)
        context.set_details('Success')

        return InfoResponse(
            title=info['title'],
            description=info['description'],
            version=info['version'],
            contact=Contact(
                name=info['contact']['name'],
                email=info['contact']['email'],
            ),
            license=License(
                name=info['license_info']['name'],
                url=info['license_info']['url'],
            )
        )

    def work(self, request: JobRequest, context):
        """Run analysis pass ``request.index`` and stream one response per step.

        Paths are built from ``request.working_dir`` and ``request.files_name``:

        * index 1 creates ``temp/<files_name>``, detects irregularities from the
          preservation audio (.wav) and audio-visual (.mov) files and saves
          ``AudioAnalyser_IrregularityFileOutput_1.json``;
        * index 2 merges that file with
          ``VideoAnalyser_IrregularityFileOutput_1.json``, extracts, featurises
          and classifies the audio irregularities and saves
          ``AudioAnalyser_IrregularityFileOutput_2.json``;
        * any other index yields a single ``INVALID_ARGUMENT`` error.

        Apart from the directory-creation step, the stream ends at the first
        step that fails.
        """
        self.console.log('Received request for computation')
        self.console.log(request)

        working_dir: str = request.working_dir
        files_name: str = request.files_name
        index: int = request.index

        audio_src = os.path.join(working_dir, "PreservationAudioFile", f"{files_name}.wav")
        video_src = os.path.join(working_dir, "PreservationAudioVisualFile", f"{files_name}.mov")

        temp_dir = os.path.join(working_dir, "temp", files_name)
        audio_irreg_1 = os.path.join(temp_dir, "AudioAnalyser_IrregularityFileOutput_1.json")
        audio_irreg_2 = os.path.join(temp_dir, "AudioAnalyser_IrregularityFileOutput_2.json")
        video_irreg_1 = os.path.join(temp_dir, "VideoAnalyser_IrregularityFileOutput_1.json")

        if index == 1:
            yield from self._first_pass(context, audio_src, video_src, temp_dir, audio_irreg_1)
        elif index == 2:
            yield from self._second_pass(
                context, audio_src, temp_dir, audio_irreg_1, video_irreg_1, audio_irreg_2
            )
        else:
            yield error_response(
                context, StatusCode.INVALID_ARGUMENT, f"Unsupported job index: {index}"
            )

    def _first_pass(self, context, audio_src, video_src, temp_dir, audio_irreg_1):
        """Pass 1: create the temp directory, detect irregularities, save file 1."""
        response, _ = try_or_error_response(
            context,
            func=os.makedirs,
            on_success_message="Folders created successfully",
            on_error_message="Unable to create temporary directory, output path already exists",
            on_error_status=StatusCode.ALREADY_EXISTS,
            # Must be a 1-tuple: a bare string would be star-unpacked into one
            # positional argument per character.
            args=(temp_dir,),
        )
        # Reported but not fatal: if the directory already exists (e.g. from a
        # previous run) the remaining steps can still write into it.
        yield response

        response, irreg1 = try_or_error_response(
            context,
            func=sf.create_irreg_file,
            args=(audio_src, video_src),
            on_success_message="Found irregularities in Audio source",
            on_error_message="Failed to create irregularity file 1",
        )
        yield response
        if response.status != "success":
            return

        # Only the write is guarded, so the ``yield`` below is never inside the
        # ``try`` (a client cancel must not be caught as a save failure).
        try:
            File(audio_irreg_1, FileType.JSON).write_content(irreg1.to_json())
        except Exception:
            self.console.print_exception()
            yield error_response(context, StatusCode.INTERNAL, "Failed to save irregularity file 1")
            return
        context.set_code(StatusCode.OK)
        context.set_details("Irregularity file 1 saved to disk")
        yield JobResponse(status="success", message="Irregularity file 1 saved to disk")

    def _second_pass(self, context, audio_src, temp_dir, audio_irreg_1, video_irreg_1, audio_irreg_2):
        """Pass 2: merge, extract, classify and save irregularity file 2."""

        def merge():
            # Pass 1 ran in an earlier request, so its output is reloaded from
            # disk rather than taken from a local variable.
            return sf.merge_irreg_files(
                self._load_irregularity_file(audio_irreg_1),
                self._load_irregularity_file(video_irreg_1),
            )

        response, irreg_file = try_or_error_response(
            context,
            func=merge,
            args=(),
            on_success_message="Irregularity files merged successfully",
            on_error_message="Failed to merge irregularity files",
        )
        yield response
        if response.status != "success":
            return

        response, irreg_file = try_or_error_response(
            context,
            func=sf.extract_audio_irregularities,
            args=(audio_src, irreg_file, temp_dir),
            on_success_message="Audio irregularities extracted",
            on_error_message="Failed to extract audio irregularities",
        )
        yield response
        if response.status != "success":
            return

        # The irregularity list and the feature set are each passed as ONE
        # argument (1-tuples), not star-unpacked element by element.
        response, irregularities_features = try_or_error_response(
            context,
            func=cl.extract_features,
            args=(irreg_file.irregularities,),
            on_success_message="Audio irregularities features extracted",
            on_error_message="Failed to extract audio irregularities features",
        )
        yield response
        if response.status != "success":
            return

        response, classification_results = try_or_error_response(
            context,
            func=cl.classify,
            args=(irregularities_features,),
            on_success_message="Audio irregularities classified",
            on_error_message="Failed to classify audio irregularities",
        )
        yield response
        if response.status != "success":
            return

        for irregularity, result in zip(irreg_file.irregularities, classification_results):
            if irregularity.source == Source.AUDIO:
                irregularity_type = result.get_irregularity_type()
                irregularity.irregularity_type = irregularity_type
                # Properties are kept only when the classifier assigned a type.
                irregularity.irregularity_properties = result if irregularity_type is not None else None

        try:
            File(audio_irreg_2, FileType.JSON).write_content(irreg_file.to_json())
        except Exception:
            self.console.print_exception()
            yield error_response(context, StatusCode.INTERNAL, "Failed to save irregularity file 2")
            return
        context.set_code(StatusCode.OK)
        context.set_details("Irregularity file 2 created")
        yield JobResponse(status="success", message="Irregularity file 2 created")

    @staticmethod
    def _load_irregularity_file(path: str) -> IrregularityFile:
        """Read an irregularity JSON file from disk into an ``IrregularityFile``."""
        # NOTE(review): assumes IrregularityFile.from_json takes the parsed JSON
        # content (the inverse of to_json(), whose result pass 1 writes through
        # File.write_content) rather than a file path - confirm.
        return IrregularityFile.from_json(File(path, FileType.JSON).get_content())


def serve(console, port: int = 50051, max_workers: int = 10):
    """Start the Audio Analyser gRPC server and block until it terminates.

    Args:
        console: rich ``Console`` handed to ``AudioAnalyserServicer`` for logging.
        port: TCP port to listen on (unencrypted, see ``add_insecure_port``).
        max_workers: size of the thread pool that executes RPC handlers.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    arp_pb2_grpc.add_AIMServicer_to_server(AudioAnalyserServicer(console), server)
    # '[::]' is the IPv6 wildcard address (typically all interfaces).
    server.add_insecure_port(f'[::]:{port}')
    server.start()
    server.wait_for_termination()


if __name__ == '__main__':
    # Announced before serve() because serve() blocks until shutdown.
    main_console = Console()
    main_console.print('Server started at localhost:50051 :satellite:')
    serve(main_console)