Commit 51d4ec30 authored by Matteo's avatar Matteo
Browse files

update

parent af2200e0
...@@ -40,7 +40,7 @@ install: ...@@ -40,7 +40,7 @@ install:
test: test:
$(POETRY) pytest $(POETRY) pytest
cd docs && $(POETRY) make doctest # cd docs && $(POETRY) make doctest
test-coverage: test-coverage:
$(POETRY) pytest --cov-config .coveragerc --cov-report term-missing --cov-report html --cov=mpai_cae_arp $(POETRY) pytest --cov-config .coveragerc --cov-report term-missing --cov-report html --cov=mpai_cae_arp
......
from mpai_cae_arp.audio._audio import AudioWave from mpai_cae_arp.audio._audio import AudioWave
from mpai_cae_arp.audio._noise import Noise from mpai_cae_arp.audio._noise import Noise
import mpai_cae_arp.audio.utils as utils from mpai_cae_arp.audio import utils
__all__ = ["AudioWave", "Noise", "utils"] __all__ = ["AudioWave", "Noise", "utils"]
...@@ -74,8 +74,7 @@ class AudioWave: ...@@ -74,8 +74,7 @@ class AudioWave:
samplerate: int samplerate: int
channels: int channels: int
def __init__(self, data: np.ndarray, bit: int, channels: int, def __init__(self, data: np.ndarray, bit: int, channels: int, samplerate: int):
samplerate: int):
if bit not in [8, 16, 24, 32]: if bit not in [8, 16, 24, 32]:
raise ValueError("bit must be 8, 16, 24 or 32") raise ValueError("bit must be 8, 16, 24 or 32")
if samplerate < 8000 or samplerate > 192000: if samplerate < 8000 or samplerate > 192000:
...@@ -92,8 +91,7 @@ class AudioWave: ...@@ -92,8 +91,7 @@ class AudioWave:
return iter(self.array) return iter(self.array)
def __getitem__(self, key): def __getitem__(self, key):
return AudioWave(self.array[key], self.bit, self.channels, return AudioWave(self.array[key], self.bit, self.channels, self.samplerate)
self.samplerate)
def __eq__(self, __o: object) -> bool: def __eq__(self, __o: object) -> bool:
if isinstance(__o, AudioWave): if isinstance(__o, AudioWave):
...@@ -149,8 +147,7 @@ class AudioWave: ...@@ -149,8 +147,7 @@ class AudioWave:
return bit, channels, samplerate return bit, channels, samplerate
@staticmethod @staticmethod
def buffer_generator_from_file(filepath: str, def buffer_generator_from_file(filepath: str, buffer_size: int = 1024 * 1024 * 8):
buffer_size: int = 1024 * 1024 * 8):
"""Return a generator that yields AudioWave objects from a file. The generator will read the file in chunks of `buffer_size` bytes. """Return a generator that yields AudioWave objects from a file. The generator will read the file in chunks of `buffer_size` bytes.
Parameters Parameters
...@@ -203,9 +200,7 @@ class AudioWave: ...@@ -203,9 +200,7 @@ class AudioWave:
""" """
data = np.array([ data = np.array([
int.from_bytes(raw_data[i:i + bit // 8], int.from_bytes(raw_data[i:i + bit // 8], byteorder='little', signed=True)
byteorder='little',
signed=True)
for i in range(0, len(raw_data), bit // 8) for i in range(0, len(raw_data), bit // 8)
]) ])
data = np.reshape(data, (-1, channels)) data = np.reshape(data, (-1, channels))
...@@ -243,8 +238,7 @@ class AudioWave: ...@@ -243,8 +238,7 @@ class AudioWave:
if force: if force:
os.makedirs(path.dirname(filepath)) os.makedirs(path.dirname(filepath))
else: else:
raise ValueError( raise ValueError(f"Directory {path.dirname(filepath)} does not exist")
f"Directory {path.dirname(filepath)} does not exist")
with wave.open(filepath, 'wb') as fp: with wave.open(filepath, 'wb') as fp:
fp.setframerate(self.samplerate) fp.setframerate(self.samplerate)
...@@ -253,6 +247,19 @@ class AudioWave: ...@@ -253,6 +247,19 @@ class AudioWave:
fp.setnframes(self.number_of_frames()) fp.setnframes(self.number_of_frames())
fp.writeframesraw(self.get_raw()) fp.writeframesraw(self.get_raw())
def set_sample_rate(self, samplerate: int):
"""Set the sample rate of the audio.
.. versionadded:: 0.4.0
Parameters
----------
samplerate: int
The new sample rate.
"""
self.array = librosa.resample(self.array, self.samplerate, samplerate)
self.samplerate = samplerate
def get_raw(self) -> bytes: def get_raw(self) -> bytes:
"""Get the raw data of the audio. """Get the raw data of the audio.
...@@ -312,9 +319,7 @@ class AudioWave: ...@@ -312,9 +319,7 @@ class AudioWave:
for channel in range(self.channels): for channel in range(self.channels):
signal = self.get_channel(channel).array signal = self.get_channel(channel).array
signal = signal / (2 ^ (self.bit - 1)) # normalize the signal signal = signal / (2 ^ (self.bit - 1)) # normalize the signal
mfccs = librosa.feature.mfcc(y=signal, mfccs = librosa.feature.mfcc(y=signal, sr=self.samplerate, n_mfcc=n_mfcc)
sr=self.samplerate,
n_mfcc=n_mfcc)
mean_mfccs = [] mean_mfccs = []
for e in mfccs: for e in mfccs:
mean_mfccs.append(np.mean(e)) mean_mfccs.append(np.mean(e))
......
from enum import Enum from enum import Enum
class EqualizationStandard(Enum): class EqualizationStandard(str, Enum):
IEC = "IEC" IEC = "IEC"
CCIR = "IEC1" CCIR = "IEC1"
NAB = "IEC2" NAB = "IEC2"
class SpeedStandard(Enum): class SpeedStandard(float, Enum):
I = 0.9375 I = 0.9375
II = 1.875 II = 1.875
III = 3.75 III = 3.75
......
...@@ -5,13 +5,13 @@ import yaml ...@@ -5,13 +5,13 @@ import yaml
from pydantic import BaseModel from pydantic import BaseModel
class FileAction(Enum): class FileAction(str, Enum):
READ = "r" READ = "r"
WRITE = "w" WRITE = "w"
APPEND = "a" APPEND = "a"
class FileType(Enum): class FileType(str, Enum):
YAML = "yaml" YAML = "yaml"
JSON = "json" JSON = "json"
......
"""
This module contains functions to print and format text in the console. Consider it as deprecated, use instead rich library.
"""
from enum import Enum from enum import Enum
END = '\033[0m' END = '\033[0m'
......
...@@ -10,28 +10,27 @@ from google.protobuf import symbol_database as _symbol_database ...@@ -10,28 +10,27 @@ from google.protobuf import symbol_database as _symbol_database
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\tarp.proto\x12\x03\x61rp\"+\n\x0bInfoRequest\x12\x12\n\x05\x66ield\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_field\"S\n\nJobRequest\x12\x13\n\x0bworking_dir\x18\x01 \x01(\t\x12\x12\n\nfiles_name\x18\x02 \x01(\t\x12\x12\n\x05index\x18\x03 \x01(\x05H\x00\x88\x01\x01\x42\x08\n\x06_index\".\n\x0bJobResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07message\x18\x02 \x01(\t\"&\n\x07\x43ontact\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05\x65mail\x18\x02 \x01(\t\"$\n\x07License\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0b\n\x03url\x18\x02 \x01(\t\"\x81\x01\n\x0cInfoResponse\x12\r\n\x05title\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x0f\n\x07version\x18\x03 \x01(\t\x12\x1d\n\x07\x63ontact\x18\x04 \x01(\x0b\x32\x0c.arp.Contact\x12\x1d\n\x07license\x18\x05 \x01(\x0b\x32\x0c.arp.License2f\n\x03\x41IM\x12\x30\n\x07getInfo\x12\x10.arp.InfoRequest\x1a\x11.arp.InfoResponse\"\x00\x12-\n\x04work\x12\x0f.arp.JobRequest\x1a\x10.arp.JobResponse\"\x00\x30\x01\x62\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tarp.proto\x12\x03\x61rp\"+\n\x0bInfoRequest\x12\x12\n\x05\x66ield\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_field\"S\n\nJobRequest\x12\x13\n\x0bworking_dir\x18\x01 \x01(\t\x12\x12\n\nfiles_name\x18\x02 \x01(\t\x12\x12\n\x05index\x18\x03 \x01(\x05H\x00\x88\x01\x01\x42\x08\n\x06_index\".\n\x0bJobResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07message\x18\x02 \x01(\t\"&\n\x07\x43ontact\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05\x65mail\x18\x02 \x01(\t\"$\n\x07License\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0b\n\x03url\x18\x02 \x01(\t\"\x81\x01\n\x0cInfoResponse\x12\r\n\x05title\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x0f\n\x07version\x18\x03 \x01(\t\x12\x1d\n\x07\x63ontact\x18\x04 \x01(\x0b\x32\x0c.arp.Contact\x12\x1d\n\x07license\x18\x05 \x01(\x0b\x32\x0c.arp.License2f\n\x03\x41IM\x12\x30\n\x07getInfo\x12\x10.arp.InfoRequest\x1a\x11.arp.InfoResponse\"\x00\x12-\n\x04work\x12\x0f.arp.JobRequest\x1a\x10.arp.JobResponse\"\x00\x30\x01\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'arp_pb2', globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'arp_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None DESCRIPTOR._options = None
_INFOREQUEST._serialized_start=18 _INFOREQUEST._serialized_start = 18
_INFOREQUEST._serialized_end=61 _INFOREQUEST._serialized_end = 61
_JOBREQUEST._serialized_start=63 _JOBREQUEST._serialized_start = 63
_JOBREQUEST._serialized_end=146 _JOBREQUEST._serialized_end = 146
_JOBRESPONSE._serialized_start=148 _JOBRESPONSE._serialized_start = 148
_JOBRESPONSE._serialized_end=194 _JOBRESPONSE._serialized_end = 194
_CONTACT._serialized_start=196 _CONTACT._serialized_start = 196
_CONTACT._serialized_end=234 _CONTACT._serialized_end = 234
_LICENSE._serialized_start=236 _LICENSE._serialized_start = 236
_LICENSE._serialized_end=272 _LICENSE._serialized_end = 272
_INFORESPONSE._serialized_start=275 _INFORESPONSE._serialized_start = 275
_INFORESPONSE._serialized_end=404 _INFORESPONSE._serialized_end = 404
_AIM._serialized_start=406 _AIM._serialized_start = 406
_AIM._serialized_end=508 _AIM._serialized_end = 508
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)
...@@ -15,15 +15,15 @@ class AIMStub(object): ...@@ -15,15 +15,15 @@ class AIMStub(object):
channel: A grpc.Channel. channel: A grpc.Channel.
""" """
self.getInfo = channel.unary_unary( self.getInfo = channel.unary_unary(
'/arp.AIM/getInfo', '/arp.AIM/getInfo',
request_serializer=arp__pb2.InfoRequest.SerializeToString, request_serializer=arp__pb2.InfoRequest.SerializeToString,
response_deserializer=arp__pb2.InfoResponse.FromString, response_deserializer=arp__pb2.InfoResponse.FromString,
) )
self.work = channel.unary_stream( self.work = channel.unary_stream(
'/arp.AIM/work', '/arp.AIM/work',
request_serializer=arp__pb2.JobRequest.SerializeToString, request_serializer=arp__pb2.JobRequest.SerializeToString,
response_deserializer=arp__pb2.JobResponse.FromString, response_deserializer=arp__pb2.JobResponse.FromString,
) )
class AIMServicer(object): class AIMServicer(object):
...@@ -48,56 +48,60 @@ class AIMServicer(object): ...@@ -48,56 +48,60 @@ class AIMServicer(object):
def add_AIMServicer_to_server(servicer, server): def add_AIMServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
'getInfo': grpc.unary_unary_rpc_method_handler( 'getInfo':
servicer.getInfo, grpc.unary_unary_rpc_method_handler(
request_deserializer=arp__pb2.InfoRequest.FromString, servicer.getInfo,
response_serializer=arp__pb2.InfoResponse.SerializeToString, request_deserializer=arp__pb2.InfoRequest.FromString,
), response_serializer=arp__pb2.InfoResponse.SerializeToString,
'work': grpc.unary_stream_rpc_method_handler( ),
servicer.work, 'work':
request_deserializer=arp__pb2.JobRequest.FromString, grpc.unary_stream_rpc_method_handler(
response_serializer=arp__pb2.JobResponse.SerializeToString, servicer.work,
), request_deserializer=arp__pb2.JobRequest.FromString,
response_serializer=arp__pb2.JobResponse.SerializeToString,
),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler('arp.AIM',
'arp.AIM', rpc_method_handlers) rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,)) server.add_generic_rpc_handlers((generic_handler, ))
# This class is part of an EXPERIMENTAL API. # This class is part of an EXPERIMENTAL API.
class AIM(object): class AIM(object):
"""Missing associated documentation comment in .proto file.""" """Missing associated documentation comment in .proto file."""
@staticmethod @staticmethod
def getInfo(request, def getInfo(request,
target, target,
options=(), options=(),
channel_credentials=None, channel_credentials=None,
call_credentials=None, call_credentials=None,
insecure=False, insecure=False,
compression=None, compression=None,
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None):
return grpc.experimental.unary_unary(request, target, '/arp.AIM/getInfo', return grpc.experimental.unary_unary(request, target, '/arp.AIM/getInfo',
arp__pb2.InfoRequest.SerializeToString, arp__pb2.InfoRequest.SerializeToString,
arp__pb2.InfoResponse.FromString, arp__pb2.InfoResponse.FromString, options,
options, channel_credentials, channel_credentials, insecure,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) call_credentials, compression,
wait_for_ready, timeout, metadata)
@staticmethod @staticmethod
def work(request, def work(request,
target, target,
options=(), options=(),
channel_credentials=None, channel_credentials=None,
call_credentials=None, call_credentials=None,
insecure=False, insecure=False,
compression=None, compression=None,
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None):
return grpc.experimental.unary_stream(request, target, '/arp.AIM/work', return grpc.experimental.unary_stream(request, target, '/arp.AIM/work',
arp__pb2.JobRequest.SerializeToString, arp__pb2.JobRequest.SerializeToString,
arp__pb2.JobResponse.FromString, arp__pb2.JobResponse.FromString, options,
options, channel_credentials, channel_credentials, insecure,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) call_credentials, compression,
wait_for_ready, timeout, metadata)
"""
This module contains functions to convert time formats. The agreed time format is hh:mm:ss.msc as a string in MPAI standards.
"""
import datetime import datetime
...@@ -93,11 +96,9 @@ def frames_to_seconds(frames: int, fps: int) -> float: ...@@ -93,11 +96,9 @@ def frames_to_seconds(frames: int, fps: int) -> float:
""" """
if fps <= 0: if fps <= 0:
raise ValueError( raise ValueError("The number of frames per second must be greater than 0")
"The number of frames per second must be greater than 0")
if frames < 0: if frames < 0:
raise ValueError( raise ValueError("The number of frames must be greater than or equal to 0")
"The number of frames must be greater than or equal to 0")
return frames / fps return frames / fps
...@@ -126,10 +127,8 @@ def seconds_to_frames(seconds: float, fps: int) -> int: ...@@ -126,10 +127,8 @@ def seconds_to_frames(seconds: float, fps: int) -> int:
""" """
if fps <= 0: if fps <= 0:
raise ValueError( raise ValueError("The number of frames per second must be greater than 0")
"The number of frames per second must be greater than 0")
if seconds < 0: if seconds < 0:
raise ValueError( raise ValueError("The number of seconds must be greater than or equal to 0")
"The number of seconds must be greater than or equal to 0")
return int(seconds * fps) return int(seconds * fps)
from .irregularity import Irregularity, IrregularityFile, IrregularityProperties, IrregularityType
from .restoration import Restoration, EditingList
from . import schema
__all__ = [
'Irregularity', 'IrregularityFile', 'IrregularityProperties', 'IrregularityType',
'Restoration', 'EditingList', 'schema'
]
from typing import TypeVar
import uuid import uuid
from enum import Enum from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel, Field
from mpai_cae_arp.audio.standards import EqualizationStandard, SpeedStandard from mpai_cae_arp.audio.standards import EqualizationStandard, SpeedStandard
from mpai_cae_arp.files import File, FileType
from mpai_cae_arp.time import time_to_seconds
class IrregularityType(Enum): class IrregularityType(str, Enum):
BRANDS_ON_TAPE = "b" BRANDS_ON_TAPE = "b"
SPLICE = "sp" SPLICE = "sp"
START_OF_TAPE = "sot" START_OF_TAPE = "sot"
...@@ -22,20 +25,28 @@ class IrregularityType(Enum): ...@@ -22,20 +25,28 @@ class IrregularityType(Enum):
BACKWARD = "sb" BACKWARD = "sb"
class Source(Enum): class Source(str, Enum):
AUDIO = "a" AUDIO = "a"
VIDEO = "v" VIDEO = "v"
BOTH = "b" BOTH = "b"
SelfIrregularity = TypeVar("SelfIrregularity", bound="Irregularity")
SelfIrregularityProperties = TypeVar("SelfIrregularityProperties",
bound="IrregularityProperties")
SelfIrregularityFile = TypeVar("SelfIrregularityFile", bound="IrregularityFile")
class IrregularityProperties(BaseModel): class IrregularityProperties(BaseModel):
reading_speed: SpeedStandard reading_speed: SpeedStandard = Field(alias="ReadingSpeedStandard")
reading_equalisation: EqualizationStandard reading_equalisation: EqualizationStandard = Field(
writing_speed: SpeedStandard alias="ReadingEqualisationStandard")
writing_equalisation: EqualizationStandard writing_speed: SpeedStandard = Field(alias="WritingSpeedStandard")
writing_equalisation: EqualizationStandard = Field(
alias="WritingEqualisationStandard")
@staticmethod @staticmethod
def from_json(json_property: dict): def from_json(json_property: dict) -> SelfIrregularityProperties:
return IrregularityProperties( return IrregularityProperties(
reading_speed=SpeedStandard(json_property["ReadingSpeedStandard"]), reading_speed=SpeedStandard(json_property["ReadingSpeedStandard"]),
reading_equalisation=EqualizationStandard( reading_equalisation=EqualizationStandard(
...@@ -44,7 +55,11 @@ class IrregularityProperties(BaseModel): ...@@ -44,7 +55,11 @@ class IrregularityProperties(BaseModel):
writing_equalisation=EqualizationStandard( writing_equalisation=EqualizationStandard(
json_property["WritingEqualisationStandard"])) json_property["WritingEqualisationStandard"]))
def to_json(self): def to_json(self) -> dict:
"""
.. deprecated:: 0.4.0
Use :meth:`IrregularityProperties.json` instead.
"""
return { return {
"ReadingSpeedStandard": self.reading_speed.value, "ReadingSpeedStandard": self.reading_speed.value,
"ReadingEqualisationStandard": self.reading_equalisation.value, "ReadingEqualisationStandard": self.reading_equalisation.value,
...@@ -63,16 +78,19 @@ class IrregularityProperties(BaseModel): ...@@ -63,16 +78,19 @@ class IrregularityProperties(BaseModel):
class Irregularity(BaseModel): class Irregularity(BaseModel):
irregularity_ID: uuid.UUID irregularity_ID: uuid.UUID = Field(default_factory=uuid.uuid4,
source: Source alias="IrregularityID")
time_label: str source: Source = Field(alias="Source")
irregularity_type: IrregularityType | None = None time_label: str = Field(alias="TimeLabel")
irregularity_properties: IrregularityProperties | None = None irregularity_type: IrregularityType | None = Field(default=None,
image_URI: str | None = None alias="IrregularityType")
audio_block_URI: str | None = None irregularity_properties: IrregularityProperties | None = Field(
default=None, alias="IrregularityProperties")
image_URI: str | None = Field(default=None, alias="ImageURI")
audio_block_URI: str | None = Field(default=None, alias="AudioBlockURI")
@staticmethod @staticmethod
def from_json(json_irreg: dict): def from_json(json_irreg: dict) -> SelfIrregularity:
properties = None properties = None
if json_irreg.get("IrregularityProperties") is not None: if json_irreg.get("IrregularityProperties") is not None:
...@@ -82,9 +100,8 @@ class Irregularity(BaseModel): ...@@ -82,9 +100,8 @@ class Irregularity(BaseModel):
raw_irreg_type = json_irreg.get("IrregularityType") raw_irreg_type = json_irreg.get("IrregularityType")
irregularity_type = None irregularity_type = None
if raw_irreg_type is not None: if raw_irreg_type is not None:
if raw_irreg_type is not ( if raw_irreg_type is not (IrregularityType.SPEED.value
IrregularityType.SPEED.value or IrregularityType.SPEED_AND_EQUALIZATION.value):
or IrregularityType.SPEED_AND_EQUALIZATION.value):
irregularity_type = IrregularityType(raw_irreg_type) irregularity_type = IrregularityType(raw_irreg_type)
else: else:
if properties.reading_equalisation != properties.writing_equalisation: if properties.reading_equalisation != properties.writing_equalisation:
...@@ -92,8 +109,7 @@ class Irregularity(BaseModel): ...@@ -92,8 +109,7 @@ class Irregularity(BaseModel):
else: else:
irregularity_type = IrregularityType.SPEED irregularity_type = IrregularityType.SPEED
return Irregularity(irregularity_ID=uuid.UUID( return Irregularity(irregularity_ID=uuid.UUID(json_irreg["IrregularityID"]),
json_irreg["IrregularityID"]),
source=Source(json_irreg["Source"]), source=Source(json_irreg["Source"]),
time_label=json_irreg["TimeLabel"], time_label=json_irreg["TimeLabel"],
irregularity_type=irregularity_type, irregularity_type=irregularity_type,
...@@ -101,7 +117,13 @@ class Irregularity(BaseModel): ...@@ -101,7 +117,13 @@ class Irregularity(BaseModel):
image_URI=json_irreg.get("ImageURI"), image_URI=json_irreg.get("ImageURI"),
audio_block_URI=json_irreg.get("AudioBlockURI")) audio_block_URI=json_irreg.get("AudioBlockURI"))
def to_json(self): def to_json(self) -> dict:
"""
Returns a dictionary with the irregularity information
.. deprecated:: 0.4.0
Use :func:`Irregularity.json` instead.
"""
dictionary = { dictionary = {
"IrregularityID": str(self.irregularity_ID), "IrregularityID": str(self.irregularity_ID),
"Source": self.source.value, "Source": self.source.value,
...@@ -118,44 +140,16 @@ class Irregularity(BaseModel): ...@@ -118,44 +140,16 @@ class Irregularity(BaseModel):
dictionary["AudioBlockURI"] = self.audio_block_URI dictionary["AudioBlockURI"] = self.audio_block_URI
if self.irregularity_properties: if self.irregularity_properties:
dictionary[ dictionary["IrregularityProperties"] = self.irregularity_properties.to_json(
"IrregularityProperties"] = self.irregularity_properties.to_json( )
)
return dictionary return dictionary
class IrregularityFile(BaseModel): class IrregularityFile(BaseModel):
# TODO: the offset calculation is not implemented yet, so it is set to None
irregularities: list[Irregularity] irregularities: list[Irregularity]
offset: int | None = None offset: int | None = None
class Config:
schema_extra = {
"example": {
"offset":
0,
"irregularities": [{
"irregularity_ID":
"a0a0a0a0-a0a0-a0a0-a0a0-a0a0a0a0a0a0",
"source":
"a",
"time_label":
"00:00:00:00",
"irregularity_type":
"b",
"irregularity_properties": {
"reading_speed": "n",
"reading_equalisation": "n",
"writing_speed": "n",
"writing_equalisation": "n"
},
"audio_block_URI":
"https://example.com/audio.wav",
}]
}
}
def __eq__(self, __o: object) -> bool: def __eq__(self, __o: object) -> bool:
if not isinstance(__o, IrregularityFile): if not isinstance(__o, IrregularityFile):
return False return False
...@@ -163,7 +157,7 @@ class IrregularityFile(BaseModel): ...@@ -163,7 +157,7 @@ class IrregularityFile(BaseModel):
return self.irregularities == __o.irregularities and self.offset == __o.offset return self.irregularities == __o.irregularities and self.offset == __o.offset
@staticmethod @staticmethod
def from_json(json_irreg: dict): def from_json(json_irreg: dict) -> SelfIrregularityFile:
irregularities = [] irregularities = []
for irreg in json_irreg["Irregularities"]: for irreg in json_irreg["Irregularities"]:
...@@ -172,7 +166,7 @@ class IrregularityFile(BaseModel): ...@@ -172,7 +166,7 @@ class IrregularityFile(BaseModel):
return IrregularityFile(irregularities=irregularities, return IrregularityFile(irregularities=irregularities,
offset=json_irreg.get("Offset")) offset=json_irreg.get("Offset"))
def to_json(self): def to_json(self) -> dict:
dictionary = { dictionary = {
"Irregularities": "Irregularities":
[irregularity.to_json() for irregularity in self.irregularities], [irregularity.to_json() for irregularity in self.irregularities],
...@@ -183,7 +177,7 @@ class IrregularityFile(BaseModel): ...@@ -183,7 +177,7 @@ class IrregularityFile(BaseModel):
return dictionary return dictionary
def add(self, irregularity: Irregularity): def add(self, irregularity: Irregularity) -> SelfIrregularityFile:
"""Add an irregularity to the list of irregularities. """Add an irregularity to the list of irregularities.
Parameters Parameters
...@@ -197,12 +191,17 @@ class IrregularityFile(BaseModel): ...@@ -197,12 +191,17 @@ class IrregularityFile(BaseModel):
if the irregularity is not a py:class:`Irregularity` object if the irregularity is not a py:class:`Irregularity` object
""" """
if not isinstance(irregularity, Irregularity): if not isinstance(irregularity, Irregularity):
raise TypeError( raise TypeError("IrregularityFile.add() expects an Irregularity object")
"IrregularityFile.add() expects an Irregularity object")
self.irregularities.append(irregularity) self.irregularities.append(irregularity)
self.order()
return self
def join(self, other): def order(self) -> SelfIrregularityFile:
"""Append the irregularities of other in current irregularity file. self.irregularities.sort(key=lambda x: time_to_seconds(x.time_label))
return self
def join(self, other) -> SelfIrregularityFile:
"""Append the irregularities of other in current irregularity file, ordered by time.
Parameters Parameters
---------- ----------
...@@ -217,3 +216,13 @@ class IrregularityFile(BaseModel): ...@@ -217,3 +216,13 @@ class IrregularityFile(BaseModel):
if not isinstance(other, IrregularityFile): if not isinstance(other, IrregularityFile):
raise TypeError("other must be an instance of IrregularityFile") raise TypeError("other must be an instance of IrregularityFile")
self.irregularities += other.irregularities self.irregularities += other.irregularities
self.order()
return self
def save_as_json_file(self, path: str) -> None:
"""
Save the editing list as a JSON file at the given path.
.. versionadded:: 0.4.0
"""
File(path=path, file_type=FileType.JSON).write_content(self.json())
import uuid
from typing import TypeVar
from pydantic import BaseModel, Field
from mpai_cae_arp.audio.standards import SpeedStandard, EqualizationStandard
from mpai_cae_arp.files import File, FileType
class Restoration(BaseModel):
id: uuid.UUID = Field(default_factory=uuid.uuid4)
preservation_audio_file_start: str
preservation_audio_file_end: str
restored_audio_file_URI: str
reading_backwards: bool
applied_speed_standard: SpeedStandard
applied_sample_frequency: int
original_equalization_standard: EqualizationStandard
Self = TypeVar("Self", bound="EditingList")
class EditingList(BaseModel):
"""
.. versionadded:: 0.4.0
"""
original_speed_standard: SpeedStandard
original_equalization_standard: EqualizationStandard
original_sample_frequency: int
restorations: list[Restoration]
def add(self, restoration: Restoration) -> Self:
self.restorations.append(restoration)
return self
def remove(self, restoration: Restoration) -> Self:
self.restorations.remove(restoration)
return self
def remove_by_id(self, restoration_id: uuid.UUID) -> Self:
filtered = list(filter(lambda r: r.id != restoration_id, self.restorations))
self.restorations = filtered
return self
def save_as_json_file(self, path: str) -> None:
File(path=path, file_type=FileType.JSON).write_content(self.json())
if __name__ == "__main__":
rest = Restoration(preservation_audio_file_start="00:00:00.000",
preservation_audio_file_end="00:00:10.000",
restored_audio_file_URI="https://www.google.com",
reading_backwards=False,
applied_sample_frequency=44100,
applied_speed_standard=SpeedStandard.III,
original_equalization_standard=EqualizationStandard.CCIR)
editing_list = EditingList(original_equalization_standard=EqualizationStandard.CCIR,
original_speed_standard=SpeedStandard.III,
original_sample_frequency=44100,
restorations=[])
editing_list.add(rest)
print(rest)
print(editing_list.json())
# pylint: disable=too-few-public-methods
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
...@@ -6,12 +7,7 @@ class Contact(BaseModel): ...@@ -6,12 +7,7 @@ class Contact(BaseModel):
email: str = Field(..., description="Email of the contact person.") email: str = Field(..., description="Email of the contact person.")
class Config: class Config:
schema_extra = { schema_extra = {"example": {"name": "John Doe", "email": "email@email.com"}}
"example": {
"name": "John Doe",
"email": "email@email.com"
}
}
class License(BaseModel): class License(BaseModel):
...@@ -33,8 +29,7 @@ class Info(BaseModel): ...@@ -33,8 +29,7 @@ class Info(BaseModel):
version: str = Field(..., description="The version of the API.") version: str = Field(..., description="The version of the API.")
contact: Contact = Field( contact: Contact = Field(
..., description="Contact information for the owners of the API.") ..., description="Contact information for the owners of the API.")
license_info: License = Field( license_info: License = Field(..., description="License information for the API.")
..., description="License information for the API.")
class Config: class Config:
schema_extra = { schema_extra = {
......
[tool.poetry] [tool.poetry]
name = "mpai-cae-arp" name = "mpai-cae-arp"
version = "0.3.2" version = "0.4.0"
description = "The MPAI CAE-ARP software API" description = "The MPAI CAE-ARP software API"
authors = ["Matteo Spanio <dev2@audioinnova.com>"] authors = ["Matteo Spanio <dev2@audioinnova.com>"]
readme = "README.md" readme = "README.md"
...@@ -23,7 +23,6 @@ pytest-cov = "^4.0.0" ...@@ -23,7 +23,6 @@ pytest-cov = "^4.0.0"
pytest-xdist = "^3.2.1" pytest-xdist = "^3.2.1"
toml = "^0.10.2" toml = "^0.10.2"
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]
sphinx = "^6.1.3" sphinx = "^6.1.3"
...@@ -41,16 +40,16 @@ relative_files = true ...@@ -41,16 +40,16 @@ relative_files = true
[tool.yapf] [tool.yapf]
blank_line_before_nested_class_or_def = true blank_line_before_nested_class_or_def = true
column_limit = 80 column_limit = 88
[tool.pylint] [tool.pylint]
max-line-length = 80 max-line-length = 88
extension-pkg-whitelist=['pydantic']
disable = [ disable = [
"C0103", # Invalid name "C0103", # Invalid name
"C0114", # Missing module docstring "C0114", # Missing module docstring
"C0115", # Missing class docstring "C0115", # Missing class docstring
"C0116", # Missing function or method docstring "C0116", # Missing function or method docstring
"C0301", # Line too long "C0301", # Line too long
"W0102", # Dangerous default value
"E1101", # Module has no member "E1101", # Module has no member
] ]
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment