Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use spawn context for local driver #9336

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ extra_checks = True

exclude = (?x)(src/ert/resources | src/_ert/forward_model_runner)

[mypy-_ert.forward_model_runner.*]
ignore_errors = True

[mypy-scipy.*]
ignore_missing_imports = True

Expand Down
6 changes: 3 additions & 3 deletions src/_ert/forward_model_runner/job_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ def sigterm_handler(_signo, _stack_frame):
os.kill(0, signal.SIGTERM)


def main():
def main(argv):
os.nice(19)
signal.signal(signal.SIGTERM, sigterm_handler)
try:
job_runner_main(sys.argv)
job_runner_main(argv)
except Exception as e:
pgid = os.getpgid(os.getpid())
os.killpg(pgid, signal.SIGTERM)
raise e


if __name__ == "__main__":
main()
main(sys.argv)
1 change: 0 additions & 1 deletion src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,6 @@ async def _wait_for_stopped_server(self) -> None:
self._server_done.wait(), timeout=self.CLOSE_SERVER_TIMEOUT
)
except asyncio.TimeoutError:
print("Timeout server done")
self._server_done.set()

async def _monitor_and_handle_tasks(self) -> None:
Expand Down
75 changes: 44 additions & 31 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import asyncio
import contextlib
import logging
import os
import multiprocessing
import signal
from asyncio.subprocess import Process
from contextlib import suppress
from pathlib import Path
from typing import MutableMapping, Optional, Set
Expand All @@ -23,6 +22,7 @@
super().__init__()
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}
self._sent_finished_events: Set[int] = set()
self._spawn_context = multiprocessing.get_context("spawn")

async def submit(
self,
Expand Down Expand Up @@ -82,12 +82,16 @@
await self._dispatch_finished_event(iens, 127)
return

await self.event_queue.put(StartedEvent(iens=iens))

returncode = 1
try:
returncode = await self._wait(proc)
logger.info(f"Realization {iens} finished with {returncode=}")
proc.start()
await self.event_queue.put(StartedEvent(iens=iens))

returncode = 1
while proc.is_alive():
proc.join(timeout=0.001)
await asyncio.sleep(1)
logger.info(f"Realization {iens} finished with exitcode={proc.exitcode}")
returncode = proc.exitcode if proc.exitcode is not None else 1
except asyncio.CancelledError:
returncode = await self._kill(proc)
finally:
Expand All @@ -99,35 +103,44 @@
await self.event_queue.put(FinishedEvent(iens=iens, returncode=returncode))
self._sent_finished_events.add(iens)

@staticmethod
async def _init(iens: int, executable: str, /, *args: str) -> Process:
async def _init(
self, iens: int, executable: str, /, *args: str
) -> multiprocessing.Process:
"""This method exists to allow for mocking it in tests"""
return await asyncio.create_subprocess_exec(
executable,
*args,
preexec_fn=os.setpgrp,
)

@staticmethod
async def _wait(proc: Process) -> int:
return await proc.wait()

@staticmethod
async def _kill(proc: Process) -> int:
try:
proc.terminate()
await asyncio.wait_for(proc.wait(), _TERMINATE_TIMEOUT)
except asyncio.TimeoutError:
proc.kill()
except ProcessLookupError:
# This will happen if the subprocess has not yet started
return signal.SIGTERM + SIGNAL_OFFSET
ret_val = await proc.wait()
from _ert.forward_model_runner.job_dispatch import main # noqa

return self._spawn_context.Process(
target=main, args=[["job_dispatch.py", *args]]
) # type: ignore

@classmethod
async def _wait(
cls,
proc: multiprocessing.Process,
max_wait: int = 10,
blocked_wait: float = 0.001,
sleep_interval: int = 1,
) -> None:
wait = 0
while wait < max_wait and proc.exitcode is None:
proc.join(timeout=blocked_wait)
await asyncio.sleep(sleep_interval)
wait += blocked_wait + sleep_interval

Check failure on line 128 in src/ert/scheduler/local_driver.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Incompatible types in assignment (expression has type "float", variable has type "int")

@classmethod
async def _kill(cls, proc: multiprocessing.Process) -> int:
proc.terminate()
await cls._wait(proc)
proc.kill()
await cls._wait(proc)
# the returncode of a subprocess will be the negative signal value
# if it terminated due to a signal.
# https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess.returncode
# we return SIGNAL_OFFSET + signal value to be in line with lfs/pbs drivers.
return -ret_val + SIGNAL_OFFSET
if proc.exitcode is not None:
return proc.exitcode + SIGNAL_OFFSET
else:
return -1 + SIGNAL_OFFSET

async def poll(self) -> None:
"""LocalDriver does not poll"""
Loading