Skip to content

Commit

Permalink
Use spawn context for local driver
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Nov 26, 2024
1 parent 968a97c commit d683a6b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 30 deletions.
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ extra_checks = True

exclude = (?x)(src/ert/resources | src/_ert/forward_model_runner)

[mypy-_ert.forward_model_runner.*]
ignore_errors = True

[mypy-scipy.*]
ignore_missing_imports = True

Expand Down
6 changes: 3 additions & 3 deletions src/_ert/forward_model_runner/job_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ def sigterm_handler(_signo, _stack_frame):
os.kill(0, signal.SIGTERM)


def main():
def main(argv):
os.nice(19)
signal.signal(signal.SIGTERM, sigterm_handler)
try:
job_runner_main(sys.argv)
job_runner_main(argv)
except Exception as e:
pgid = os.getpgid(os.getpid())
os.killpg(pgid, signal.SIGTERM)
raise e


if __name__ == "__main__":
main()
main(sys.argv)
1 change: 0 additions & 1 deletion src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,6 @@ async def _wait_for_stopped_server(self) -> None:
self._server_done.wait(), timeout=self.CLOSE_SERVER_TIMEOUT
)
except asyncio.TimeoutError:
print("Timeout server done")
self._server_done.set()

async def _monitor_and_handle_tasks(self) -> None:
Expand Down
51 changes: 25 additions & 26 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import asyncio
import contextlib
import logging
import os
import multiprocessing
import signal
from asyncio.subprocess import Process
from contextlib import suppress
from pathlib import Path
from typing import MutableMapping, Optional, Set
Expand Down Expand Up @@ -36,6 +35,7 @@ async def submit(
realization_memory: Optional[int] = 0,
) -> None:
self._tasks[iens] = asyncio.create_task(self._run(iens, executable, *args))
self._spawn_context = multiprocessing.get_context("spawn")
with suppress(KeyError):
self._sent_finished_events.remove(iens)

Expand Down Expand Up @@ -73,6 +73,7 @@ async def _run(self, iens: int, executable: str, /, *args: str) -> None:
executable,
*args,
)
proc.start()
except FileNotFoundError as err:
# /bin/sh uses returncode 127 for FileNotFound, so copy that
# behaviour.
Expand All @@ -86,10 +87,12 @@ async def _run(self, iens: int, executable: str, /, *args: str) -> None:

returncode = 1
try:
returncode = await self._wait(proc)
logger.info(f"Realization {iens} finished with {returncode=}")
except asyncio.CancelledError:
returncode = await self._kill(proc)
while proc.is_alive():
proc.join(timeout=0.001)
await asyncio.sleep(1)
logger.info(f"Realization {iens} finished with exitcode={proc.exitcode}")
returncode = proc.exitcode if proc.exitcode is not None else 1
print(f"{proc.exitcode=},{returncode=}")
finally:
await self._dispatch_finished_event(iens, returncode)

Expand All @@ -99,35 +102,31 @@ async def _dispatch_finished_event(self, iens: int, returncode: int) -> None:
await self.event_queue.put(FinishedEvent(iens=iens, returncode=returncode))
self._sent_finished_events.add(iens)

@staticmethod
async def _init(iens: int, executable: str, /, *args: str) -> Process:
async def _init(
self, iens: int, executable: str, /, *args: str
) -> multiprocessing.Process:
"""This method exists to allow for mocking it in tests"""
return await asyncio.create_subprocess_exec(
executable,
*args,
preexec_fn=os.setpgrp,
)
from _ert.forward_model_runner.job_dispatch import main # noqa

@staticmethod
async def _wait(proc: Process) -> int:
return await proc.wait()
return self._spawn_context.Process(
target=main, args=[["job_dispatch.py", *args]]
) # type: ignore

@staticmethod
async def _kill(proc: Process) -> int:
try:
proc.terminate()
await asyncio.wait_for(proc.wait(), _TERMINATE_TIMEOUT)
except asyncio.TimeoutError:
async def _kill(proc: multiprocessing.Process) -> int:
proc.terminate()
proc.join(timeout=10)
if proc.exitcode is None:
proc.kill()
except ProcessLookupError:
# This will happen if the subprocess has not yet started
return signal.SIGTERM + SIGNAL_OFFSET
ret_val = await proc.wait()
proc.join(timeout=10)
# the returncode of a subprocess will be the negative signal value
# if it terminated due to a signal.
# https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess.returncode
# we return SIGNAL_OFFSET + signal value to be in line with lsf/pbs drivers.
return -ret_val + SIGNAL_OFFSET
if proc.exitcode is not None:
return proc.exitcode + SIGNAL_OFFSET
else:
return -1 + SIGNAL_OFFSET

async def poll(self) -> None:
"""LocalDriver does not poll"""

0 comments on commit d683a6b

Please sign in to comment.