Using keyword arguments in calculate_result fails #203
Comments
Hi @nbereux! Thanks for the bug report. I think the posted code is not quite using torchquad as intended. The calculate_result function is typically used when you want to reuse the sample points. I tried to rewrite your example to use the code as described in the docs (https://torchquad.readthedocs.io/en/main/tutorial.html#reusing-sample-points). It then looks like this:

```python
import torch
from torchquad import Simpson


def integrand1(x):
    return torch.rand(x.shape)


integration_domain = torch.Tensor([[0.0, 1.0]])
dim = 1
N = 101
integrator = Simpson()
grid_points, hs, n_per_dim = integrator.calculate_grid(N, integration_domain)
function_values, _ = integrator.evaluate_integrand(integrand1, grid_points)
integral1 = integrator.calculate_result(function_values, dim, n_per_dim, hs, integration_domain)
print(integral1)
```

This works for me as intended. I suspect there is an incompatibility because you computed the values yourself instead of using the built-in functions. Do you think this clarifies the issue?
I don't think it clarifies the issue. I agree I'm not using the code as intended, but reusing your code and passing keyword arguments instead of positional arguments gives the same error:

```python
import torch
from torchquad import Simpson


def integrand1(x):
    return torch.rand(x.shape)


integration_domain = torch.Tensor([[0.0, 1.0]])
dim = 1
N = 101
integrator = Simpson()
grid_points, hs, n_per_dim = integrator.calculate_grid(N, integration_domain)
function_values, _ = integrator.evaluate_integrand(integrand1, grid_points)
# Only the next line changed
integral1 = integrator.calculate_result(
    function_values=function_values,
    dim=dim,
    n_per_dim=n_per_dim,
    hs=hs,
    integration_domain=integration_domain,
)
print(integral1)
```
A fix would be:

```python
import warnings

from autoray import numpy as anp


def expand_func_values_and_squeeze_integral(f):
    """This decorator ensures that the trailing dimension of integrands is indeed the integrand dimension.
    This is pertinent in the 1d case when the sampled values are often of shape `(N,)`. Then, to maintain backward
    consistency, we squeeze the result in the 1d case so it does not have any trailing dimensions.

    Args:
        f (Callable): the wrapped function
    """

    def wrap(*args, **kwargs):
        # function_values is either the first positional argument after self, or a keyword argument
        if len(args) > 1:
            function_values = args[1]
        elif "function_values" in kwargs:
            function_values = kwargs["function_values"]
        else:
            raise ValueError(
                "function_values argument not found in either positional or keyword arguments."
            )
        # i.e., we only have one dimension, or the second dimension (that of the integrand) is 1
        is_1d = len(function_values.shape) == 1 or (
            len(function_values.shape) == 2 and function_values.shape[1] == 1
        )
        if is_1d:
            warnings.warn(
                "DEPRECATION WARNING: In future versions of torchquad, an array-like object will be returned."
            )
            # Add the trailing integrand dimension wherever function_values was passed
            if len(args) > 1:
                args = (args[0], anp.expand_dims(function_values, axis=1), *args[2:])
            else:
                kwargs["function_values"] = anp.expand_dims(function_values, axis=1)
            result = f(*args, **kwargs)
            return anp.squeeze(result)
        return f(*args, **kwargs)

    return wrap
```
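To check that the proposed decorator handles positional, keyword-only, and mixed calls, here is a minimal sketch that swaps autoray's `anp` for plain NumPy so it runs without torchquad. `FakeIntegrator` is hypothetical and not part of torchquad; its `calculate_result` simply sums over the sample axis as a placeholder for a real quadrature rule.

```python
import warnings

import numpy as np


def expand_func_values_and_squeeze_integral(f):
    """Decorator proposed in this thread, with NumPy standing in for autoray's anp."""

    def wrap(*args, **kwargs):
        # Look for function_values positionally first (args[0] is self), then by keyword
        if len(args) > 1:
            function_values = args[1]
        elif "function_values" in kwargs:
            function_values = kwargs["function_values"]
        else:
            raise ValueError(
                "function_values argument not found in either positional or keyword arguments."
            )
        is_1d = len(function_values.shape) == 1 or (
            len(function_values.shape) == 2 and function_values.shape[1] == 1
        )
        if is_1d:
            warnings.warn(
                "DEPRECATION WARNING: In future versions of torchquad, an array-like object will be returned."
            )
            if len(args) > 1:
                args = (args[0], np.expand_dims(function_values, axis=1), *args[2:])
            else:
                kwargs["function_values"] = np.expand_dims(function_values, axis=1)
            return np.squeeze(f(*args, **kwargs))
        return f(*args, **kwargs)

    return wrap


class FakeIntegrator:
    """Hypothetical stand-in for an integrator; not part of torchquad."""

    @expand_func_values_and_squeeze_integral
    def calculate_result(self, function_values, dim):
        # Placeholder "integration": sum over the sample axis, keeping the integrand axis
        return function_values.sum(axis=0)


integrator = FakeIntegrator()
values = np.array([1.0, 2.0, 3.0])
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the deprecation warning for this demo
    positional = integrator.calculate_result(values, 1)
    keyword = integrator.calculate_result(function_values=values, dim=1)
    mixed = integrator.calculate_result(values, dim=1)
print(positional, keyword, mixed)  # 6.0 6.0 6.0
```

A more general alternative (not what was proposed here) would be to bind the call with `inspect.signature(f).bind(*args, **kwargs)` and read `function_values` from the bound arguments, which avoids reasoning about argument positions by hand.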
Issue
Problem Description
When keyword arguments are used in the Simpson().calculate_result() function, it fails with IndexError: tuple index out of range. The issue seems to come from expand_func_values_and_squeeze_integral checking only args and not kwargs: when all arguments are passed as keywords, args is reduced to (self,), and reading args[1] leads to the error above.

Expected Behavior
It should run properly, just as it does when the arguments are not passed as keywords.
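The failure mode described above can be reproduced without torchquad. The sketch below uses a hypothetical `buggy_decorator` and `FakeIntegrator` (neither is part of torchquad) to mimic a decorator that reads `function_values` only from `args[1]`: a positional call works, while an all-keyword call leaves `args == (self,)` and raises the same `IndexError`.

```python
def buggy_decorator(f):
    """Hypothetical stand-in for a decorator that only inspects positional args."""

    def wrap(*args, **kwargs):
        # args[0] is self; with keyword-only calls, args == (self,) and this raises IndexError
        function_values = args[1]
        return f(*args, **kwargs)

    return wrap


class FakeIntegrator:
    """Hypothetical integrator used only to demonstrate the bug."""

    @buggy_decorator
    def calculate_result(self, function_values, dim):
        return sum(function_values)


integrator = FakeIntegrator()
print(integrator.calculate_result([1.0, 2.0], 1))  # 3.0

try:
    integrator.calculate_result(function_values=[1.0, 2.0], dim=1)
except IndexError as err:
    print(f"IndexError: {err}")  # IndexError: tuple index out of range
```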
What Needs to Be Done
Change expand_func_values_and_squeeze_integral to handle both keyword-only arguments and a mixture of positional and keyword arguments.

How Can It Be Tested or Reproduced
Here is a minimal example that fails with torchquad v0.4.0