Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions gtwrap/matlab_wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,10 +1269,15 @@ def wrap_collector_function_return_types(self, return_type, func_id):
return_type_text = self.wrap_collector_function_shared_return(
return_type.typename, shared_obj, func_id, func_id == 0)
else:
return_type_text += 'wrap_shared_ptr({0},"{1}", false);{new_line}' \
is_virtual = any(
cls.name == return_type.typename.name and cls.is_virtual
for cls in self.classes
)
return_type_text += 'wrap_shared_ptr({0},"{1}", {2});{new_line}' \
.format(shared_obj,
self._format_type_name(return_type.typename,
separator='.'),
'true' if is_virtual else 'false',
new_line=new_line)
else:
return_type_text += 'wrap< {0} >(pairResult.{1});{2}'.format(
Expand Down Expand Up @@ -1339,8 +1344,13 @@ def _collector_return(self,
obj=obj)

if ctype.typename.name not in self.ignore_namespace:
is_virtual = any(
cls.name == ctype.typename.name and cls.is_virtual
for cls in self.classes
)
expanded += textwrap.indent(
'out[0] = wrap_shared_ptr({0}, false);'.format(shared_obj),
'out[0] = wrap_shared_ptr({0}, {1});'.format(
shared_obj, 'true' if is_virtual else 'false'),
prefix=' ')
else:
expanded += ' out[0] = wrap< {0} >({1});'.format(
Expand Down Expand Up @@ -1741,8 +1751,9 @@ def generate_preamble(self):

if cls.is_virtual:
class_name, class_name_sep = self.get_class_name(cls)
matlab_class_name = self._format_class_name(cls, separator='.')
rtti_classes += ' types.insert(std::make_pair(typeid({}).name(), "{}"));\n' \
.format(class_name_sep, class_name)
.format(class_name_sep, matlab_class_name)

# Generate the typedef instances string
typedef_instances = "\n".join(typedef_instances)
Expand Down
53 changes: 53 additions & 0 deletions tests/expected/matlab/Base.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
%class Base, see Doxygen page for details
%at https://gtsam.org/doxygen/
%
%-------Static Methods-------
%Create(double x) : returns gtsam::Base
%
%-------Serialization Interface-------
%string_serialize() : returns string
%string_deserialize(string serialized) : returns Base
%
classdef Base < handle
properties
ptr_Base = 0
end
methods
function obj = Base(varargin)
if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void'))) && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)
if nargin == 2
my_ptr = varargin{2};
else
my_ptr = inheritance_wrapper(58, varargin{2});
end
inheritance_wrapper(57, my_ptr);
else
error('Arguments do not match any overload of Base constructor');
end
obj.ptr_Base = my_ptr;
end

function delete(obj)
inheritance_wrapper(59, obj.ptr_Base);
end

function display(obj), obj.print(''); end
%DISPLAY Calls print on the object
function disp(obj), obj.display; end
%DISP Calls print on the object
end

methods(Static = true)
function varargout = Create(varargin)
% CREATE usage: Create(double x) : returns gtsam.Base
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 1 && isa(varargin{1},'double')
varargout{1} = inheritance_wrapper(60, varargin{:});
return
end

error('Arguments do not match any overload of function Base.Create');
end

end
end
36 changes: 36 additions & 0 deletions tests/expected/matlab/Derived.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
%class Derived, see Doxygen page for details
%at https://gtsam.org/doxygen/
%
classdef Derived < Base
properties
ptr_Derived = 0
end
methods
function obj = Derived(varargin)
if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void'))) && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)
if nargin == 2
my_ptr = varargin{2};
else
my_ptr = inheritance_wrapper(62, varargin{2});
end
base_ptr = inheritance_wrapper(61, my_ptr);
else
error('Arguments do not match any overload of Derived constructor');
end
obj = obj@Base(uint64(5139824614673773682), base_ptr);
obj.ptr_Derived = my_ptr;
end

function delete(obj)
inheritance_wrapper(63, obj.ptr_Derived);
end

function display(obj), obj.print(''); end
%DISPLAY Calls print on the object
function disp(obj), obj.display; end
%DISP Calls print on the object
end

methods(Static = true)
end
end
112 changes: 112 additions & 0 deletions tests/expected/matlab/inheritance_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ typedef std::set<std::shared_ptr<ForwardKinematicsFactor>*> Collector_ForwardKin
static Collector_ForwardKinematicsFactor collector_ForwardKinematicsFactor;
typedef std::set<std::shared_ptr<ParentHasTemplateDouble>*> Collector_ParentHasTemplateDouble;
static Collector_ParentHasTemplateDouble collector_ParentHasTemplateDouble;
typedef std::set<std::shared_ptr<Base>*> Collector_Base;
static Collector_Base collector_Base;
typedef std::set<std::shared_ptr<Derived>*> Collector_Derived;
static Collector_Derived collector_Derived;


void _deleteAllObjects()
Expand Down Expand Up @@ -64,6 +68,18 @@ void _deleteAllObjects()
collector_ParentHasTemplateDouble.erase(iter++);
anyDeleted = true;
} }
{ for(Collector_Base::iterator iter = collector_Base.begin();
iter != collector_Base.end(); ) {
delete *iter;
collector_Base.erase(iter++);
anyDeleted = true;
} }
{ for(Collector_Derived::iterator iter = collector_Derived.begin();
iter != collector_Derived.end(); ) {
delete *iter;
collector_Derived.erase(iter++);
anyDeleted = true;
} }

if(anyDeleted)
cout <<
Expand All @@ -84,6 +100,8 @@ void _inheritance_RTTIRegister() {
types.insert(std::make_pair(typeid(MyTemplateA).name(), "MyTemplateA"));
types.insert(std::make_pair(typeid(ForwardKinematicsFactor).name(), "ForwardKinematicsFactor"));
types.insert(std::make_pair(typeid(ParentHasTemplateDouble).name(), "ParentHasTemplateDouble"));
types.insert(std::make_pair(typeid(Base).name(), "Base"));
types.insert(std::make_pair(typeid(Derived).name(), "Derived"));


mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry");
Expand Down Expand Up @@ -698,6 +716,79 @@ void ParentHasTemplateDouble_deconstructor_56(int nargout, mxArray *out[], int n
delete self;
}

void Base_collectorInsertAndMakeBase_57(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<Base> Shared;

Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_Base.insert(self);
}

void Base_upcastFromVoid_58(int nargout, mxArray *out[], int nargin, const mxArray *in[]) {
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<Base> Shared;
std::shared_ptr<void> *asVoid = *reinterpret_cast<std::shared_ptr<void>**> (mxGetData(in[0]));
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
Shared *self = new Shared(std::static_pointer_cast<Base>(*asVoid));
*reinterpret_cast<Shared**>(mxGetData(out[0])) = self;
}

void Base_deconstructor_59(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef std::shared_ptr<Base> Shared;
checkArguments("delete_Base",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_Base::iterator item;
item = collector_Base.find(self);
if(item != collector_Base.end()) {
collector_Base.erase(item);
}
delete self;
}

void Base_Create_60(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("Base.Create",nargout,nargin,1);
double x = unwrap< double >(in[0]);
out[0] = wrap_shared_ptr(Base::Create(x),"gtsam.Base", true);
}

void Derived_collectorInsertAndMakeBase_61(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<Derived> Shared;

Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_Derived.insert(self);

typedef std::shared_ptr<Base> SharedBase;
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<SharedBase**>(mxGetData(out[0])) = new SharedBase(*self);
}

void Derived_upcastFromVoid_62(int nargout, mxArray *out[], int nargin, const mxArray *in[]) {
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<Derived> Shared;
std::shared_ptr<void> *asVoid = *reinterpret_cast<std::shared_ptr<void>**> (mxGetData(in[0]));
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
Shared *self = new Shared(std::static_pointer_cast<Derived>(*asVoid));
*reinterpret_cast<Shared**>(mxGetData(out[0])) = self;
}

void Derived_deconstructor_63(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef std::shared_ptr<Derived> Shared;
checkArguments("delete_Derived",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_Derived::iterator item;
item = collector_Derived.find(self);
if(item != collector_Derived.end()) {
collector_Derived.erase(item);
}
delete self;
}


void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
Expand Down Expand Up @@ -881,6 +972,27 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
case 56:
ParentHasTemplateDouble_deconstructor_56(nargout, out, nargin-1, in+1);
break;
case 57:
Base_collectorInsertAndMakeBase_57(nargout, out, nargin-1, in+1);
break;
case 58:
Base_upcastFromVoid_58(nargout, out, nargin-1, in+1);
break;
case 59:
Base_deconstructor_59(nargout, out, nargin-1, in+1);
break;
case 60:
Base_Create_60(nargout, out, nargin-1, in+1);
break;
case 61:
Derived_collectorInsertAndMakeBase_61(nargout, out, nargin-1, in+1);
break;
case 62:
Derived_upcastFromVoid_62(nargout, out, nargin-1, in+1);
break;
case 63:
Derived_deconstructor_63(nargout, out, nargin-1, in+1);
break;
}
} catch(const std::exception& e) {
mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str());
Expand Down
5 changes: 5 additions & 0 deletions tests/expected/python/inheritance_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ PYBIND11_MODULE(inheritance_py, m_) {

py::class_<ParentHasTemplate<double>, MyTemplate<double>, std::shared_ptr<ParentHasTemplate<double>>>(m_, "ParentHasTemplateDouble");

py::class_<Base, std::shared_ptr<Base>>(m_, "Base")
.def_static("Create",[](double x){return Base::Create(x);}, gtwrap::internal::py_arg<double>("x"));

py::class_<Derived, Base, std::shared_ptr<Derived>>(m_, "Derived");


#include "python/specializations.h"

Expand Down
9 changes: 9 additions & 0 deletions tests/fixtures/inheritance.i
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,12 @@ virtual class ForwardKinematicsFactor : gtsam::BetweenFactor<gtsam::Pose3> {};

template <T = {double}>
virtual class ParentHasTemplate : MyTemplate<T> {};

// A base class with a smart static factory that may downcast
virtual class Base {
static gtsam::Base* Create(double x);
};

// A derived class returned by Base::Create when appropriate
virtual class Derived : Base {
};
2 changes: 2 additions & 0 deletions tests/test_matlab_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def test_inheritance(self):
'MyTemplatePoint2.m',
'ForwardKinematicsFactor.m',
'ParentHasTemplateDouble.m',
'Base.m',
'Derived.m',
]

for file in files:
Expand Down
Loading