torchmetrics Accuracy() fails if get_metrics() is called before test_on_dataset #272
Labels: bug (Something isn't working)
Describe the bug
The torchmetrics Accuracy() class raises "RuntimeError: You have to have determined mode." if wrapper.get_metrics() is called before wrapper.test_on_dataset, or if wrapper.test_on_dataset is never called. In contrast, Baal's Accuracy() class handles this case by returning 'test_accuracy': nan.
To Reproduce
In this gist, a LeNet-5 model with MC Dropout is trained on the entire MNIST dataset for 1 epoch (no acquisitions).
The script uses Baal's Accuracy() class by default and adds the torchmetrics Accuracy() class with the option --torchmetrics. It evaluates on the test set by default and skips this step with the option --no-test.
Running python baal_error_torchmetrics.py --no-test:
Running python baal_error_torchmetrics.py --torchmetrics:
Running python baal_error_torchmetrics.py --torchmetrics --no-test:
Expected behavior
The torchmetrics Accuracy() class should also return 'test_torch_accuracy': nan, just like Baal's Accuracy() class.
Version:
Additional context
/
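A possible workaround for the expected behavior described above, sketched under assumptions rather than taken from Baal or torchmetrics: wrap the metric so that compute() returns nan until update() has been called at least once, which matches how Baal's Accuracy() reports 'test_accuracy': nan. ToyAccuracy and NanIfNoUpdate are hypothetical names. ToyAccuracy only imitates the reported RuntimeError so the example runs without torchmetrics installed; it is not the torchmetrics implementation.

```python
class ToyAccuracy:
    """Hypothetical stand-in that imitates the reported failure: it raises
    if compute() is called before any update(). NOT torchmetrics code."""

    def __init__(self):
        self.correct = 0
        self.total = 0
        self.mode = None  # only set by update(), so compute() fails before it

    def update(self, preds, target):
        # Count exact matches between predicted and true class labels.
        self.mode = "multiclass"
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        if self.mode is None:
            raise RuntimeError("You have to have determined mode.")
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0
        self.mode = None


class NanIfNoUpdate:
    """Hypothetical wrapper: compute() returns nan until update() has been
    called, mirroring Baal's 'test_accuracy': nan behavior."""

    def __init__(self, metric):
        self.metric = metric
        self.updated = False

    def update(self, *args, **kwargs):
        self.updated = True
        self.metric.update(*args, **kwargs)

    def compute(self):
        if not self.updated:
            return float("nan")  # no data seen yet, e.g. test set not evaluated
        return self.metric.compute()

    def reset(self):
        self.updated = False
        self.metric.reset()


acc = NanIfNoUpdate(ToyAccuracy())
print(acc.compute())  # nan: no test data seen yet
acc.update([0, 1, 2, 1], [0, 1, 1, 1])
print(acc.compute())  # 0.75 (3 of 4 predictions match)
```

The wrapper only relies on update(), compute() and reset(), which torchmetrics metrics also expose, so the same idea could in principle wrap the real torchmetrics Accuracy(); how it would plug into Baal's wrapper is an assumption not covered here.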