Skip to content

Commit

Permalink
Simplify ensemble evaluator shutdown timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
JHolba committed Nov 26, 2024
1 parent 4b2c2f5 commit 09a0a42
Showing 1 changed file with 34 additions and 45 deletions.
79 changes: 34 additions & 45 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,56 +386,45 @@ async def _start_running(self) -> None:

CLOSE_SERVER_TIMEOUT = 60

async def _wait_for_stopped_server(self) -> None:
"""
When the ensemble is done, we wait for the server to stop
with a timeout.
"""
try:
await asyncio.wait_for(
self._server_done.wait(), timeout=self.CLOSE_SERVER_TIMEOUT
)
except asyncio.TimeoutError:
print("Timeout server done")
self._server_done.set()

async def _monitor_and_handle_tasks(self) -> None:
pending: Iterable[asyncio.Task[None]] = self._ee_tasks
stop_timeout_task: Optional[asyncio.Task[None]] = None
timeout = None

while True:
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
try:
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED, timeout=timeout
)
for task in done:
if task_exception := task.exception():
self.log_exception(task_exception, task.get_name())
raise task_exception
elif task.get_name() == "server_task":
return
elif task.get_name() == "ensemble_task":
timeout = self.CLOSE_SERVER_TIMEOUT
continue
else:
msg = f"Something went wrong, {task.get_name()} is done prematurely!"
logger.error(msg)
raise RuntimeError(msg)
except asyncio.TimeoutError:
print("Timeout server done")
self._server_done.set()

@staticmethod
def log_exception(task_exception: BaseException, task_name: str) -> None:
exc_traceback = "".join(
traceback.format_exception(
None, task_exception, task_exception.__traceback__
)
for task in done:
if task_exception := task.exception():
exc_traceback = "".join(
traceback.format_exception(
None, task_exception, task_exception.__traceback__
)
)
logger.error(
(
f"Exception in evaluator task {task.get_name()}: {task_exception}\n"
f"Traceback: {exc_traceback}"
)
)
raise task_exception
elif task.get_name() == "server_task":
if stop_timeout_task:
stop_timeout_task.cancel()
return
elif task.get_name() == "ensemble_task":
stop_timeout_task = asyncio.create_task(
self._wait_for_stopped_server()
)
continue
else:
msg = (
f"Something went wrong, {task.get_name()} is done prematurely!"
)
logger.error(msg)
raise RuntimeError(msg)
)
logger.error(
(
f"Exception in evaluator task {task_name}: {task_exception}\n"
f"Traceback: {exc_traceback}"
)
)

async def run_and_get_successful_realizations(self) -> List[int]:
await self._start_running()
Expand Down

0 comments on commit 09a0a42

Please sign in to comment.