Commit 17c75969 authored by Matteo Spanio's avatar Matteo Spanio
Browse files

Upload New File

parent 4f7a8fc5
Pipeline #29 canceled with stages
import os
from concurrent import futures
import grpc
from grpc import StatusCode
from rich.console import Console
import cv2
import numpy as np
from tape_irregularity_classifier.lib import (
CLASSES,
get_irregularity_file,
load_model,
MODEL_PATH,
prepare_images_for_prediction,
PROB
)
from mpai_cae_arp.files import File, FileType
from mpai_cae_arp.types.irregularity import IrregularityFile, Irregularity, IrregularityType
from mpai_cae_arp.network import arp_pb2_grpc as arp_pb2_grpc
from mpai_cae_arp.network.arp_pb2 import (
JobRequest,
JobResponse,
Contact,
InfoResponse,
License,
)
PORT = os.getenv("PORT") or '50051'
info = File('config.yml', FileType.YAML).get_content()
def error_response(context, status, message):
context.set_code(status)
context.set_details(message)
return JobResponse(status="error", message=message)
class TapeIrregularityClassifierServicer(arp_pb2_grpc.AIMServicer):
def __init__(self, console: Console):
self.console = console
def getInfo(self, request, context) -> InfoResponse:
self.console.log('Received request for AIM info')
context.set_code(StatusCode.OK)
context.set_details('Success')
return InfoResponse(
title=info['title'],
description=info['description'],
version=info['version'],
contact=Contact(
name=info['contact']['name'],
email=info['contact']['email'],
),
license=License(
name=info['license_info']['name'],
url=info['license_info']['url'],
)
)
def work(self, request: JobRequest, context):
self.console.log('Received request for computation')
self.console.log(request)
working_dir: str = request.working_dir
files_name: str = request.files_name
temp_dir = os.path.join(working_dir, "temp", files_name)
video_irreg_2 = os.path.join(temp_dir, "VideoAnalyser_IrregularityFileOutput2.json")
audio_irreg_2 = os.path.join(temp_dir, "AudioAnalyser_IrregularityFileOutput2.json")
tape_irreg_1 = os.path.join(temp_dir,'TapeIrregularityClassifier_IrregularityFileOutput1.json' )
tape_irreg_2 = os.path.join(temp_dir,'TapeIrregularityClassifier_IrregularityFileOutput2.json' )
audio_irreg_2 = IrregularityFile.from_json(File(audio_irreg_2, FileType.JSON).get_content())
model = load_model(MODEL_PATH)
# Open Irregularity File from Video Analyser
try:
irregularity_file = get_irregularity_file(video_irreg_2)
except FileNotFoundError:
yield error_response(context, StatusCode.NOT_FOUND, f"{video_irreg_2} not found!")
# Get Irregularities array
input_irregularities = irregularity_file.irregularities
yield JobResponse(
status="success",
message=f"{len(input_irregularities)} Irregularities listed in {video_irreg_2}")
# The following array will maintain only the Irregularities that can be processed by the model
processed_irregularities: list[Irregularity] = []
# Array of images corresponding to the Irregularities
images = []
yield JobResponse(status="success", message="Loading Irregularity Images...")
for irr in input_irregularities:
img_path = irr.image_URI
if os.path.exists(img_path):
# The Irregularity Image was found, then save its data
processed_irregularities.append(irr)
img = cv2.imread(img_path)
# All images must be resized to 224x224 due to network requirements
resized = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
resized.astype('float32')
images.append(resized)
else:
yield JobResponse(
status="warning",
message=f"Irregularity Image not found! IrregularityID: {irr.irregularity_ID}")
yield JobResponse(
status="success",
message=f"{len(processed_irregularities)} Irregularities can be processed by the model")
images = prepare_images_for_prediction(images)
prediction = model.predict(images, batch_size=1)
# We are interested in the first array of prediction probabilities
probabilities = prediction[0]
# Get the position where the highest probability is stored
# It corresponds to the class index (0, 1, ...)
y = np.argmax(probabilities, axis=1)
predicted_labels = []
# The following array will correspond to the Irregularities that will be listed in the JSON
output_irregularities = []
for i in range(0, len(processed_irregularities)):
# Retrieve current Irregularity probability
irr_prob = probabilities[i][y[i]] * 100
predicted_labels.append(CLASSES[y[i]] + ' - ' + str(irr_prob) + ' %')
irr = processed_irregularities[i]
# Consider valid Irregularities only those with probability > prob
if probabilities[i][y[i]] * 100 > PROB:
irr.irregularity_type = IrregularityType(CLASSES[y[i]])
output_irregularities.append(irr)
else:
yield JobResponse(
status="warning",
message=f"Irregularity number {i+1} with probability {irr_prob:.2f} % discarded")
# TODO: Delete corresponding image from temp_path
# os.remove(irr['ImageURI'])
yield JobResponse(
status="success",
message=f"{len(output_irregularities)} Irregularities with probability > {PROB} %")
# TODO: IrregularityFileOutput1 shall present only Irregularities relevant to Tape Audio Restoration
File(tape_irreg_1, FileType.JSON)\
.write_content(
IrregularityFile(irregularities=output_irregularities)\
.to_json())
yield JobResponse(
status="success",
message=f"{len(output_irregularities)} Irregularities listed in {tape_irreg_1}")
File(tape_irreg_2, FileType.JSON)\
.write_content(
IrregularityFile(irregularities=output_irregularities, offset=audio_irreg_2.offset)\
.to_json())
yield JobResponse(
status="success",
message=f"{len(output_irregularities)} Irregularities listed in {tape_irreg_2}")
def serve(console):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
arp_pb2_grpc.add_AIMServicer_to_server(TapeIrregularityClassifierServicer(console), server)
server.add_insecure_port(f'[::]:{PORT}')
server.start()
server.wait_for_termination()
if __name__ == '__main__':
console = Console()
console.print(f'Server started at localhost:{PORT} :satellite:')
serve(console)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment