From 29d4073ec98dd0add66f66d1c007db306ef50602 Mon Sep 17 00:00:00 2001 From: gpolazzo Date: Mon, 29 Jun 2026 18:33:06 -0400 Subject: [PATCH] Fix MATLAB wrapper virtual class downcasting for namespaced types --- gtwrap/matlab_wrapper/wrapper.py | 17 ++- tests/expected/matlab/Base.m | 53 +++++++++ tests/expected/matlab/Derived.m | 36 ++++++ tests/expected/matlab/inheritance_wrapper.cpp | 112 ++++++++++++++++++ tests/expected/python/inheritance_pybind.cpp | 5 + tests/fixtures/inheritance.i | 9 ++ tests/test_matlab_wrapper.py | 2 + 7 files changed, 231 insertions(+), 3 deletions(-) create mode 100644 tests/expected/matlab/Base.m create mode 100644 tests/expected/matlab/Derived.m diff --git a/gtwrap/matlab_wrapper/wrapper.py b/gtwrap/matlab_wrapper/wrapper.py index 6776320a..b7944b2f 100755 --- a/gtwrap/matlab_wrapper/wrapper.py +++ b/gtwrap/matlab_wrapper/wrapper.py @@ -1269,10 +1269,15 @@ def wrap_collector_function_return_types(self, return_type, func_id): return_type_text = self.wrap_collector_function_shared_return( return_type.typename, shared_obj, func_id, func_id == 0) else: - return_type_text += 'wrap_shared_ptr({0},"{1}", false);{new_line}' \ + is_virtual = any( + cls.name == return_type.typename.name and cls.is_virtual + for cls in self.classes + ) + return_type_text += 'wrap_shared_ptr({0},"{1}", {2});{new_line}' \ .format(shared_obj, self._format_type_name(return_type.typename, separator='.'), + 'true' if is_virtual else 'false', new_line=new_line) else: return_type_text += 'wrap< {0} >(pairResult.{1});{2}'.format( @@ -1339,8 +1344,13 @@ def _collector_return(self, obj=obj) if ctype.typename.name not in self.ignore_namespace: + is_virtual = any( + cls.name == ctype.typename.name and cls.is_virtual + for cls in self.classes + ) expanded += textwrap.indent( - 'out[0] = wrap_shared_ptr({0}, false);'.format(shared_obj), + 'out[0] = wrap_shared_ptr({0}, {1});'.format( + shared_obj, 'true' if is_virtual else 'false'), prefix=' ') else: expanded += ' out[0] = wrap< {0} >({1});'.format( @@ -1741,8 +1751,9 @@ def generate_preamble(self): if cls.is_virtual: class_name, class_name_sep = self.get_class_name(cls) + matlab_class_name = self._format_class_name(cls, separator='.') rtti_classes += ' types.insert(std::make_pair(typeid({}).name(), "{}"));\n' \ - .format(class_name_sep, class_name) + .format(class_name_sep, matlab_class_name) # Generate the typedef instances string typedef_instances = "\n".join(typedef_instances) diff --git a/tests/expected/matlab/Base.m b/tests/expected/matlab/Base.m new file mode 100644 index 00000000..0eae5590 --- /dev/null +++ b/tests/expected/matlab/Base.m @@ -0,0 +1,53 @@ +%class Base, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Static Methods------- +%Create(double x) : returns gtsam::Base +% +%-------Serialization Interface------- +%string_serialize() : returns string +%string_deserialize(string serialized) : returns Base +% +classdef Base < handle + properties + ptr_Base = 0 + end + methods + function obj = Base(varargin) + if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void'))) && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + if nargin == 2 + my_ptr = varargin{2}; + else + my_ptr = inheritance_wrapper(58, varargin{2}); + end + inheritance_wrapper(57, my_ptr); + else + error('Arguments do not match any overload of Base constructor'); + end + obj.ptr_Base = my_ptr; + end + + function delete(obj) + inheritance_wrapper(59, obj.ptr_Base); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + function varargout = Create(varargin) + % CREATE usage: Create(double x) : returns gtsam.Base + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 1 && isa(varargin{1},'double') + varargout{1} = inheritance_wrapper(60, varargin{:}); + return + end + + error('Arguments do not match any overload of function Base.Create'); + end + + end +end diff --git a/tests/expected/matlab/Derived.m b/tests/expected/matlab/Derived.m new file mode 100644 index 00000000..bc9c9179 --- /dev/null +++ b/tests/expected/matlab/Derived.m @@ -0,0 +1,36 @@ +%class Derived, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +classdef Derived < Base + properties + ptr_Derived = 0 + end + methods + function obj = Derived(varargin) + if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void'))) && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + if nargin == 2 + my_ptr = varargin{2}; + else + my_ptr = inheritance_wrapper(62, varargin{2}); + end + base_ptr = inheritance_wrapper(61, my_ptr); + else + error('Arguments do not match any overload of Derived constructor'); + end + obj = obj@Base(uint64(5139824614673773682), base_ptr); + obj.ptr_Derived = my_ptr; + end + + function delete(obj) + inheritance_wrapper(63, obj.ptr_Derived); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/tests/expected/matlab/inheritance_wrapper.cpp b/tests/expected/matlab/inheritance_wrapper.cpp index 5081ad86..24e166eb 100644 --- a/tests/expected/matlab/inheritance_wrapper.cpp +++ b/tests/expected/matlab/inheritance_wrapper.cpp @@ -20,6 +20,10 @@ typedef std::set*> Collector_ForwardKin static Collector_ForwardKinematicsFactor collector_ForwardKinematicsFactor; typedef std::set*> Collector_ParentHasTemplateDouble; static Collector_ParentHasTemplateDouble collector_ParentHasTemplateDouble; +typedef std::set*> Collector_Base; +static Collector_Base collector_Base; +typedef std::set*> Collector_Derived; +static Collector_Derived collector_Derived; void _deleteAllObjects() @@ -64,6 +68,18 @@ void _deleteAllObjects() collector_ParentHasTemplateDouble.erase(iter++); anyDeleted = true; } } + { for(Collector_Base::iterator iter = collector_Base.begin(); + iter != collector_Base.end(); ) { + delete *iter; + collector_Base.erase(iter++); + anyDeleted = true; + } } + { for(Collector_Derived::iterator iter = collector_Derived.begin(); + iter != collector_Derived.end(); ) { + delete *iter; + collector_Derived.erase(iter++); + anyDeleted = true; + } } if(anyDeleted) cout << @@ -84,6 +100,8 @@ void _inheritance_RTTIRegister() { types.insert(std::make_pair(typeid(MyTemplateA).name(), "MyTemplateA")); types.insert(std::make_pair(typeid(ForwardKinematicsFactor).name(), "ForwardKinematicsFactor")); types.insert(std::make_pair(typeid(ParentHasTemplateDouble).name(), "ParentHasTemplateDouble")); + types.insert(std::make_pair(typeid(Base).name(), "Base")); + types.insert(std::make_pair(typeid(Derived).name(), "Derived")); mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); @@ -698,6 +716,79 @@ void ParentHasTemplateDouble_deconstructor_56(int nargout, mxArray *out[], int n delete self; } +void Base_collectorInsertAndMakeBase_57(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_Base.insert(self); +} + +void Base_upcastFromVoid_58(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + std::shared_ptr *asVoid = *reinterpret_cast**> (mxGetData(in[0])); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + Shared *self = new Shared(std::static_pointer_cast(*asVoid)); + *reinterpret_cast(mxGetData(out[0])) = self; +} + +void Base_deconstructor_59(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef std::shared_ptr Shared; + checkArguments("delete_Base",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_Base::iterator item; + item = collector_Base.find(self); + if(item != collector_Base.end()) { + collector_Base.erase(item); + } + delete self; +} + +void Base_Create_60(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("Base.Create",nargout,nargin,1); + double x = unwrap< double >(in[0]); + out[0] = wrap_shared_ptr(Base::Create(x),"gtsam.Base", true); +} + +void Derived_collectorInsertAndMakeBase_61(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_Derived.insert(self); + + typedef std::shared_ptr SharedBase; + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast(mxGetData(out[0])) = new SharedBase(*self); +} + +void Derived_upcastFromVoid_62(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + std::shared_ptr *asVoid = *reinterpret_cast**> (mxGetData(in[0])); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + Shared *self = new Shared(std::static_pointer_cast(*asVoid)); + *reinterpret_cast(mxGetData(out[0])) = self; +} + +void Derived_deconstructor_63(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef std::shared_ptr Shared; + checkArguments("delete_Derived",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_Derived::iterator item; + item = collector_Derived.find(self); + if(item != collector_Derived.end()) { + collector_Derived.erase(item); + } + delete self; +} + void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { @@ -881,6 +972,27 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) case 56: ParentHasTemplateDouble_deconstructor_56(nargout, out, nargin-1, in+1); break; + case 57: + Base_collectorInsertAndMakeBase_57(nargout, out, nargin-1, in+1); + break; + case 58: + Base_upcastFromVoid_58(nargout, out, nargin-1, in+1); + break; + case 59: + Base_deconstructor_59(nargout, out, nargin-1, in+1); + break; + case 60: + Base_Create_60(nargout, out, nargin-1, in+1); + break; + case 61: + Derived_collectorInsertAndMakeBase_61(nargout, out, nargin-1, in+1); + break; + case 62: + Derived_upcastFromVoid_62(nargout, out, nargin-1, in+1); + break; + case 63: + Derived_deconstructor_63(nargout, out, nargin-1, in+1); + break; } } catch(const std::exception& e) { mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str()); diff --git a/tests/expected/python/inheritance_pybind.cpp b/tests/expected/python/inheritance_pybind.cpp index 39729f28..29c1cc22 100644 --- a/tests/expected/python/inheritance_pybind.cpp +++ b/tests/expected/python/inheritance_pybind.cpp @@ -85,6 +85,11 @@ PYBIND11_MODULE(inheritance_py, m_) { py::class_, MyTemplate, std::shared_ptr>>(m_, "ParentHasTemplateDouble"); + py::class_>(m_, "Base") + .def_static("Create",[](double x){return Base::Create(x);}, gtwrap::internal::py_arg("x")); + + py::class_>(m_, "Derived"); + #include "python/specializations.h" diff --git a/tests/fixtures/inheritance.i b/tests/fixtures/inheritance.i index 6f220f2a..8da210ad 100644 --- a/tests/fixtures/inheritance.i +++ b/tests/fixtures/inheritance.i @@ -27,3 +27,12 @@ virtual class ForwardKinematicsFactor : gtsam::BetweenFactor {}; template virtual class ParentHasTemplate : MyTemplate {}; + +// A base class with a smart static factory that may downcast +virtual class Base { + static gtsam::Base* Create(double x); +}; + +// A derived class returned by Base::Create when appropriate +virtual class Derived : Base { +}; diff --git a/tests/test_matlab_wrapper.py b/tests/test_matlab_wrapper.py index 3b98442b..522955eb 100644 --- a/tests/test_matlab_wrapper.py +++ b/tests/test_matlab_wrapper.py @@ -256,6 +256,8 @@ def test_inheritance(self): 'MyTemplatePoint2.m', 'ForwardKinematicsFactor.m', 'ParentHasTemplateDouble.m', + 'Base.m', + 'Derived.m', ] for file in files: