import os
from typing import Dict, Union

import docker
from docker.types import Mount
from typeguard import typechecked

try:
    from common_utils import msg_builder, rabbitmq
except ModuleNotFoundError:
    from common_module.common_utils import msg_builder, rabbitmq

try:
    from common_utils.logger import create_logger
except ModuleNotFoundError:
    from common_module.common_utils.logger import create_logger

# Docker client built from this process's environment (docker.from_env).
docker_client = docker.from_env()
# Prefix put in front of project volume / network / container names.
# NOTE(review): presumably the Docker Compose project name — confirm.
docker_prefix = os.environ["DOCKER_PREFIX"]

LOGGER = create_logger(__name__)

# Module started when a job begins (job_status == "start").
_FIRST_MODULE = "osd_tvs"

# Pipeline order: when the key module reports process_status == "completed",
# the value module is started next. "osd_ave" is the last stage and has no
# successor, so its completion launches nothing.
_NEXT_MODULE: Dict[str, str] = {
    "osd_tvs": "mmc_aus",
    "mmc_aus": "osd_vcd",
    "osd_vcd": "mmc_sir",
    "mmc_sir": "mmc_asr",
    "mmc_asr": "paf_fir",
    "paf_fir": "osd_ava",
    "osd_ava": "osd_avs",
    "osd_avs": "osd_ave",
}

# Environment variables forwarded from this process to every module container.
_COMMON_ENV_KEYS = (
    "LOG_LEVEL",
    "AI_FW_DIR",
    "MIDDLEWARE_USER",
    "MIDDLEWARE_PASSWORD",
    "MIDDLEWARE_VIRTUALHOST",
    "MIDDLEWARE_PORT",
    "GIT_NAME",
    "GIT_TOKEN",
)


def _queue_name(module: str) -> str:
    """Return the name of the RabbitMQ queue that ``module`` consumes."""
    return f"queue_module_{module}"


def _is_truthy_flag(value: object) -> bool:
    """Return True for ``True`` or a case-insensitive ``"true"`` string, else False."""
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        return value.lower() == "true"
    # None, numbers and any other type are treated as "not set".
    return False


@typechecked
def mod_envs(module: str) -> Dict[str, str]:
    """Return the environment variables to pass to a module container.

    Every module receives the variables listed in ``_COMMON_ENV_KEYS``;
    ``mmc_aus`` additionally receives ``HUGGINGFACE_TOKEN``.

    Raises:
        KeyError: if a required variable is missing from this process's env.
    """
    envs = {key: os.environ[key] for key in _COMMON_ENV_KEYS}
    if module == "mmc_aus":
        envs["HUGGINGFACE_TOKEN"] = os.environ["HUGGINGFACE_TOKEN"]
    return envs


@typechecked
def mod_vols(module: str) -> list:
    """Return the Docker mounts for a module container.

    - bind mount of the host path ``PATH_SHARED`` at ``AI_FW_DIR`` in the
      container (shared storage);
    - named volume ``<docker_prefix>_<module>_logs`` at ``/LOGS``.
    """
    # Mount(target=path inside container, source=host path or volume name, type)
    # https://stackoverflow.com/a/74524696
    storage_mnt = Mount(
        target=os.environ["AI_FW_DIR"], source=os.environ["PATH_SHARED"], type="bind"
    )
    # Docker prepends the directory (project) name to named volumes to avoid
    # clashes, so the volume name has to carry the same prefix.
    # https://forums.docker.com/t/docker-compose-prepends-directory-name-to-named-volumes/32835/
    log_mnt = Mount(
        target="/LOGS", source=f"{docker_prefix}_{module}_logs", type="volume"
    )
    return [storage_mnt, log_mnt]


@typechecked
def get_rabbitmq_health() -> str:
    """Return the health status reported by Docker for the RabbitMQ container.

    The container is looked up by name ``<docker_prefix>-rabbitmq_dc-1``;
    ``run_container`` only proceeds when this returns ``"healthy"``.
    """
    rabbitmq_health = docker_client.containers.get(
        f"{docker_prefix}-rabbitmq_dc-1"
    ).health
    LOGGER.debug(f"RabbitMQ is {rabbitmq_health}")
    return rabbitmq_health


@typechecked
def run_container(module: str) -> str:
    """Start a detached container for ``module`` and return its ID.

    The image is ``<module>:<TAG>``; the container is named
    ``<docker_prefix>-<module>-1``, attached to the
    ``<docker_prefix>_aifw_net`` network and requests GPU device "0".

    Raises:
        Exception: if RabbitMQ is not reported as "healthy".
    """
    if get_rabbitmq_health() != "healthy":
        raise Exception("RabbitMQ is unhealthy")

    container = docker_client.containers.run(
        image=f"{module}:{os.environ['TAG']}",
        detach=True,
        # Keep the container after it exits; the controller removes it later
        # via rm_container().
        auto_remove=False,
        environment=mod_envs(module),
        mounts=mod_vols(module),
        device_requests=[
            docker.types.DeviceRequest(device_ids=["0"], capabilities=[["gpu"]])
        ],
        # Docker prepends the directory (project) name to the network too.
        network=f"{docker_prefix}_aifw_net",
        name=f"{docker_prefix}-{module}-1",
    )
    return container.id


@typechecked
def rm_container(container_id: str) -> None:
    """Kill (if still running) and remove the container(s) matching ``container_id``.

    Exited containers are included: containers are started with
    ``auto_remove=False``, so an already-stopped container would otherwise be
    left behind, keeping its fixed name (see ``run_container``) taken.
    """
    # all=True: list() only returns running containers by default.
    for container in docker_client.containers.list(
        all=True, filters={"id": container_id}
    ):
        # force=True kills a running container before removing it and simply
        # removes a stopped one, so there is no kill-vs-exit race.
        container.remove(force=True)


@typechecked
def run(message_body: dict, worker: rabbitmq.Worker) -> bool:
    """Controller message handler: advance a job through the module pipeline.

    Steps:
      1. Validate ``message_body`` with ``msg_builder.validate_message``.
      2. If ``programme["container_id"]`` is present, remove that container.
      3. On ``job_status == "start"`` start ``_FIRST_MODULE``; otherwise, on
         ``process_status == "completed"``, start the successor of
         ``programme["module"]`` from ``_NEXT_MODULE`` (if any).
      4. If a module was started, write ``force`` and the new
         ``container_id`` into ``programme`` and send the message to that
         module's queue.

    Returns:
        Always False.
    """
    if not msg_builder.validate_message(
        message_body,
        ["external_id", "application", "uid", "job_status", "process_status"],
    ):
        LOGGER.error(f"bad msg: {message_body=}")
        return False

    programme = message_body["programme"]
    uid = programme["uid"]

    force = False
    if "force" in programme:
        force_v = programme["force"]
        LOGGER.debug(f"{type(force_v)=}")
        force = _is_truthy_flag(force_v)
    LOGGER.debug(f"{force=}")

    if "container_id" in programme:
        doomed_container_id = programme["container_id"]
        LOGGER.debug(f"tryna kill {doomed_container_id=}...")
        rm_container(doomed_container_id)

    is_start = programme["job_status"] == "start"
    if is_start:
        next_module = _FIRST_MODULE
    elif programme["process_status"] == "completed":
        # job_status = "working": start the successor of the reporting module.
        next_module = _NEXT_MODULE.get(programme["module"])
    else:
        next_module = None

    queue_name = ""
    container_id = ""
    if next_module is not None:
        queue_name = _queue_name(next_module)
        container_id = run_container(next_module)

    if not is_start:
        LOGGER.debug(
            "module "
            f"{programme.get('module')} "
            f"{programme['process_status']} "
            "response"
        )
    LOGGER.debug(f"{queue_name=}")

    if queue_name:
        # send data to analyzer
        s_job_status = programme.get("job_status", "--")
        s_process_status = programme.get("process_status", "--")
        LOGGER.debug(
            f"[TRACE][{uid}][SEND] queue: {queue_name}"
            f" -- job_status: {s_job_status}"
            f" -- process_status: {s_process_status}"
        )
        programme["force"] = force
        programme["container_id"] = container_id
        worker.send_messages(queue=queue_name, messages=(message_body,))
    elif "process_status" in programme:
        # NOTE(review): a "completed" report from the last stage (osd_ave)
        # also ends up here and is logged as a bad msg — confirm intended.
        process_status = programme["process_status"]
        if not ("working" in process_status or "failed" in process_status):
            LOGGER.warning(
                f"bad msg: {message_body['programme']['process_status']=}"
            )
    return False