Skip to content

Commit fdbfd0f

Browse files
authored
Merge pull request #2491 from stan-dev/fix/conditional_var_value
Have conditional_var_value_t return correct matrix type if scalar is var
2 parents 29843fe + 0034da2 commit fdbfd0f

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

stan/math/rev/meta/conditional_var_value.hpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,23 @@
33

44
#include <stan/math/rev/core/var.hpp>
55
#include <stan/math/rev/meta/plain_type.hpp>
6+
#include <stan/math/prim/meta/promote_scalar_type.hpp>
67

78
namespace stan {
89

910
/**
10-
* Constructs a prim type or var_value from a scalar and a container.
11+
* Conditionally construct a var_value container based on a scalar type. For
12+
* var types as the scalar, the `var_value<Matrix>`'s inner type will have a
13+
* scalar of a double.
1114
* @tparam T_scalar Determines the scalar (var/double) of the type.
1215
* @tparam T_container Determines the container (matrix/vector/matrix_cl ...) of
13-
* the type. This must be a prim type.
16+
* the type.
1417
*/
1518
template <typename T_scalar, typename T_container, typename = void>
1619
struct conditional_var_value {
1720
using type = std::conditional_t<is_var<scalar_type_t<T_scalar>>::value,
18-
math::var_value<plain_type_t<T_container>>,
21+
math::var_value<math::promote_scalar_t<
22+
double, plain_type_t<T_container>>>,
1923
plain_type_t<T_container>>;
2024
};
2125
template <typename T_scalar, typename T_container>

test/unit/math/rev/meta/conditional_var_value_test.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ TEST(MathMetaRev, conditional_var_value_matrix) {
4444
conditional_var_value_t<double, Eigen::MatrixXd>);
4545
EXPECT_SAME_TYPE(var_value<Eigen::MatrixXd>,
4646
conditional_var_value_t<var, Eigen::MatrixXd>);
47+
EXPECT_SAME_TYPE(var_value<Eigen::MatrixXd>,
48+
conditional_var_value_t<var, Eigen::Matrix<var, -1, -1>>);
4749
}
4850

4951
TEST(MathMetaRev, conditional_var_value_expression) {

0 commit comments

Comments
 (0)