Skip to content

Commit

Permalink
Basic conversion of parameters for global wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgoss committed Sep 25, 2023
1 parent 8901513 commit fc311f7
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 29 deletions.
193 changes: 192 additions & 1 deletion Source/buildimplementationrust.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,38 @@ func BuildImplementationRust(component ComponentDefinition, outputFolder string,
return err
}

IntfWrapperFileName := BaseName + "_interface_wrapper.rs"
IntfWrapperFilePath := path.Join(outputFolder, IntfWrapperFileName)
modfiles = append(modfiles, IntfWrapperFilePath)
log.Printf("Creating \"%s\"", IntfWrapperFilePath)
IntfWrapperRSFile, err := CreateLanguageFile(IntfWrapperFilePath, indentString)
if err != nil {
return err
}
IntfWrapperRSFile.WriteCLicenseHeader(component,
fmt.Sprintf("This is an autogenerated Rust implementation file in order to allow easy\ndevelopment of %s. The functions in this file need to be implemented. It needs to be generated only once.", LibraryName),
true)
err = buildRustWrapper(component, IntfWrapperRSFile, InterfaceMod)
if err != nil {
return err
}

IntfHandleFileName := BaseName + "_interface_handle.rs"
IntfHandleFilePath := path.Join(outputFolder, IntfHandleFileName)
modfiles = append(modfiles, IntfHandleFilePath)
log.Printf("Creating \"%s\"", IntfHandleFilePath)
IntfHandleRSFile, err := CreateLanguageFile(IntfHandleFilePath, indentString)
if err != nil {
return err
}
IntfHandleRSFile.WriteCLicenseHeader(component,
fmt.Sprintf("This is an autogenerated Rust implementation file in order to allow easy\ndevelopment of %s. The functions in this file need to be implemented. It needs to be generated only once.", LibraryName),
true)
err = buildRustHandle(component, IntfHandleRSFile, InterfaceMod)
if err != nil {
return err
}

IntfWrapperStubName := path.Join(stubOutputFolder, BaseName+stubIdentifier+".rs")
modfiles = append(modfiles, IntfWrapperStubName)
if forceRebuild || !FileExists(IntfWrapperStubName) {
Expand Down Expand Up @@ -344,7 +376,7 @@ func buildRustGlobalStubFile(component ComponentDefinition, w LanguageWriter, In
w.Writeln("use %s::*;", InterfaceMod)
w.Writeln("")
w.Writeln("// Wrapper struct to implement the wrapper trait for global methods")
w.Writeln("struct CWrapper;")
w.Writeln("pub struct CWrapper;")
w.Writeln("")
w.Writeln("impl Wrapper for CWrapper {")
w.Writeln("")
Expand Down Expand Up @@ -474,3 +506,162 @@ func buildRustStubFile(component ComponentDefinition, class ComponentDefinitionC
w.Writeln("")
return nil
}

func buildRustWrapper(component ComponentDefinition, w LanguageWriter, InterfaceMod string) error {
// Imports
ModName := strings.ToLower(component.NameSpace)
w.Writeln("")
w.Writeln("// Calls from the C-Interface to the Rust traits via the CWrapper")
w.Writeln("// These are the symbols exposed in the shared object interface")
w.Writeln("")
w.Writeln("use %s::*;", InterfaceMod)
w.Writeln("use %s::CWrapper;", ModName)
w.Writeln("use std::ffi::{c_char, CStr};")
w.Writeln("")
cprefix := ModName + "_"
// Build the global methods
err := writeGlobalRustWrapper(component, w, cprefix)
if err != nil {
return err
}
return nil
}

func buildRustHandle(component ComponentDefinition, w LanguageWriter, InterfaceMod string) error {
w.Writeln("")
w.Writeln("// Handle passed through interface define the casting maps needed to extract")
w.Writeln("")
w.Writeln("use %s::*;", InterfaceMod)
w.Writeln("")
w.Writeln("impl HandleImpl {")
w.AddIndentationLevel(1)
for i := 0; i < len(component.Classes); i++ {
class := component.Classes[i]
writeRustHandleAs(component, w, class, false)
writeRustHandleAs(component, w, class, true)
w.Writeln("")
}
w.AddIndentationLevel(-1)
w.Writeln("}")
return nil
}

func writeRustHandleAs(component ComponentDefinition, w LanguageWriter, class ComponentDefinitionClass, mut bool) error {
//parents, err := getParentList(component, class)
//if err != nil {
// return err
//}
Name := class.ClassName
if !mut {
w.Writeln("pub fn as_%s(&self) -> Option<&dyn %s> {", toSnakeCase(Name), Name)
} else {
w.Writeln("pub fn as_mut_%s(&mut self) -> Option<&mut dyn %s> {", toSnakeCase(Name), Name)
}
w.AddIndentationLevel(1)
w.Writeln("None")
w.AddIndentationLevel(-1)
w.Writeln("}")
return nil
}

func writeGlobalRustWrapper(component ComponentDefinition, w LanguageWriter, cprefix string) error {
methods := component.Global.Methods
for i := 0; i < len(methods); i++ {
method := methods[i]
err := writeRustMethodWrapper(method, w, cprefix)
if err != nil {
return err
}
w.Writeln("")
}
return nil
}

func writeRustMethodWrapper(method ComponentDefinitionMethod, w LanguageWriter, cprefix string) error {
// Build up the parameter strings
parameterString := ""
returnName := ""
for k := 0; k < len(method.Params); k++ {
param := method.Params[k]
RustParams, err := generateRustParameters(param, true)
if err != nil {
return err
}
for i := 0; i < len(RustParams); i++ {
RustParam := RustParams[i]
if parameterString == "" {
parameterString += fmt.Sprintf("%s : %s", RustParam.ParamName, RustParam.ParamType)
} else {
parameterString += fmt.Sprintf(", %s : %s", RustParam.ParamName, RustParam.ParamType)
}
}
}
w.Writeln("pub fn %s%s(%s) -> i32 {", cprefix, strings.ToLower(method.MethodName), parameterString)
w.AddIndentationLevel(1)
argsString := ""
for k := 0; k < len(method.Params); k++ {
param := method.Params[k]
OName, err := writeRustParameterConversionArg(param, w)
if err != nil {
return err
}
if OName != "" {
if argsString == "" {
argsString = OName
} else {
argsString += fmt.Sprintf(", %s", OName)
}
}
}
if returnName != "" {
w.Writeln("let %s = CWrapper::%s(%s);", returnName, toSnakeCase(method.MethodName), argsString)
} else {
w.Writeln("CWrapper::%s(%s);", toSnakeCase(method.MethodName), argsString)
}
w.Writeln("// All ok")
w.Writeln("0")
w.AddIndentationLevel(-1)
w.Writeln("}")
return nil
}

func writeRustParameterConversionArg(param ComponentDefinitionParam, w LanguageWriter) (string, error) {
if param.ParamPass == "return" {
return "", nil
}
IName := toSnakeCase(param.ParamName)
OName := "_" + IName
switch param.ParamType {
case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "single", "double":
if param.ParamPass == "in" {
w.Writeln("let %s = %s;", OName, IName)
} else {
w.Writeln("let %s = unsafe {&mut *%s};", OName, IName)
}
case "class", "optionalclass":
if param.ParamPass == "in" {
HName := "_Handle_" + IName
w.Writeln("let %s = unsafe {&*%s};", HName, IName)
w.Writeln("let %s = %s.as_%s().unwrap();", OName, HName, toSnakeCase(param.ParamClass))
} else {
HName := "_Handle_" + IName
w.Writeln("let %s = unsafe {&mut *%s};", HName, IName)
w.Writeln("let %s = %s.as_mut_%s().unwrap();", OName, HName, toSnakeCase(param.ParamClass))
}
case "string":
if param.ParamPass == "in" {
SName := "_Str_" + IName
w.Writeln("let %s = unsafe{ CStr::from_ptr(%s) };", SName, IName)
w.Writeln("let %s = %s.to_str().unwrap();", OName, SName)
} else {
SName := "_String_" + IName
w.Writeln("let mut %s = String::new();", SName)
w.Writeln("let %s = &mut %s;", OName, SName)
}
case "bool", "pointer", "struct", "basicarray", "structarray":
//return fmt.Errorf("Conversion of type %s for parameter %s not supported", param.ParamType, IName)
default:
return "", fmt.Errorf("Conversion of type %s for parameter %s not supported as is unknown", param.ParamType, IName)
}
return OName, nil
}
100 changes: 72 additions & 28 deletions Source/languagerust.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func toSnakeCase(BaseType string) string {

func writeRustBaseTypeDefinitions(componentdefinition ComponentDefinition, w LanguageWriter, NameSpace string, BaseName string) error {
w.Writeln("#[allow(unused_imports)]")
w.Writeln("use std::ffi;")
w.Writeln("use std::ffi::c_void;")
w.Writeln("")
w.Writeln("/*************************************************************************************************************************")
w.Writeln(" Version definition for %s", NameSpace)
Expand All @@ -63,10 +63,28 @@ func writeRustBaseTypeDefinitions(componentdefinition ComponentDefinition, w Lan
w.Writeln("")

w.Writeln("/*************************************************************************************************************************")
w.Writeln(" Basic pointers definition for %s", NameSpace)
w.Writeln(" Handle definiton for %s", NameSpace)
w.Writeln("**************************************************************************************************************************/")
w.Writeln("")
w.Writeln("type Handle = std::ffi::c_void;")
w.Writeln("// Enum of all traits - this acts as a handle as we pass trait pointers through the interface")
w.Writeln("pub enum HandleImpl {")
w.AddIndentationLevel(1)
for i := 0; i < len(componentdefinition.Classes); i++ {
class := componentdefinition.Classes[i]
if i != len(componentdefinition.Classes)-1 {
w.Writeln("T%s(Box<dyn %s>),", class.ClassName, class.ClassName)
} else {
w.Writeln("T%s(Box<dyn %s>)", class.ClassName, class.ClassName)
}
}
w.AddIndentationLevel(-1)
w.Writeln("}")
w.Writeln("")
w.Writeln("pub type Handle = *mut HandleImpl;")
for i := 0; i < len(componentdefinition.Classes); i++ {
class := componentdefinition.Classes[i]
w.Writeln("pub type %sHandle = *mut HandleImpl;", class.ClassName)
}

if len(componentdefinition.Enums) > 0 {
w.Writeln("/*************************************************************************************************************************")
Expand Down Expand Up @@ -211,6 +229,25 @@ func generateRustParameters(param ComponentDefinitionParam, isPlain bool) ([]Rus
}

if isPlain {
if param.ParamType == "string" {
if param.ParamPass == "out" {
Params = make([]RustParameter, 3)
Params[0].ParamType = "u64"
Params[0].ParamName = toSnakeCase(param.ParamName) + "_buffer_size"
Params[0].ParamComment = fmt.Sprintf("* @param[in] %s - size of the buffer (including trailing 0)", Params[0].ParamName)

Params[1].ParamType = "*mut u64"
Params[1].ParamName = toSnakeCase(param.ParamName) + "_needed_chars"
Params[1].ParamComment = fmt.Sprintf("* @param[out] %s - will be filled with the count of the written bytes, or needed buffer size.", Params[1].ParamName)

Params[2].ParamType = "*mut c_char"
Params[2].ParamName = toSnakeCase(param.ParamName) + "_buffer"
Params[2].ParamComment = fmt.Sprintf("* @param[out] %s - %s buffer of %s, may be NULL", Params[2].ParamName, param.ParamClass, param.ParamDescription)

return Params, nil
}
}

if param.ParamType == "basicarray" {
return nil, fmt.Errorf("Not yet handled")
}
Expand All @@ -231,50 +268,51 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st
RustParamTypeName := ""
ParamTypeName := param.ParamType
ParamClass := param.ParamClass
BasicType := false
switch ParamTypeName {
case "uint8":
RustParamTypeName = "u8"

BasicType = true
case "uint16":
RustParamTypeName = "u16"

BasicType = true
case "uint32":
RustParamTypeName = "u32"

BasicType = true
case "uint64":
RustParamTypeName = "u64"

BasicType = true
case "int8":
RustParamTypeName = "i8"

BasicType = true
case "int16":
RustParamTypeName = "i16"

BasicType = true
case "int32":
RustParamTypeName = "i32"

BasicType = true
case "int64":
RustParamTypeName = "i64"

BasicType = true
case "bool":
if isPlain {
RustParamTypeName = "u8"
} else {
RustParamTypeName = "bool"
}

BasicType = true
case "single":
RustParamTypeName = "f32"

BasicType = true
case "double":
RustParamTypeName = "f64"

BasicType = true
case "pointer":
RustParamTypeName = "c_void"

BasicType = true
case "string":
if isPlain {
RustParamTypeName = "*mut char"
RustParamTypeName = "*const c_char"
} else {
switch param.ParamPass {
case "out":
Expand All @@ -290,17 +328,12 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st
if isPlain {
RustParamTypeName = fmt.Sprintf("u16")
} else {
switch param.ParamPass {
case "out":
RustParamTypeName = fmt.Sprintf("&mut %s", ParamClass)
case "in", "return":
RustParamTypeName = fmt.Sprintf("%s", ParamClass)
}
RustParamTypeName = ParamClass
}

BasicType = true
case "functiontype":
RustParamTypeName = fmt.Sprintf("%s", ParamClass)

BasicType = true
case "struct":
if isPlain {
RustParamTypeName = fmt.Sprintf("*mut %s", ParamClass)
Expand Down Expand Up @@ -353,13 +386,14 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st

case "class", "optionalclass":
if isPlain {
RustParamTypeName = fmt.Sprintf("Handle")
RustParamTypeName = fmt.Sprintf("%sHandle", ParamClass)
BasicType = true
} else {
switch param.ParamPass {
case "out":
RustParamTypeName = fmt.Sprintf("&mut impl %s", ParamClass)
RustParamTypeName = fmt.Sprintf("&mut dyn %s", ParamClass)
case "in":
RustParamTypeName = fmt.Sprintf("& impl %s", ParamClass)
RustParamTypeName = fmt.Sprintf("& dyn %s", ParamClass)
case "return":
RustParamTypeName = fmt.Sprintf("Box<dyn %s>", ParamClass)
}
Expand All @@ -368,6 +402,16 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st
default:
return "", fmt.Errorf("invalid parameter type \"%s\" for Rust parameter", ParamTypeName)
}

if BasicType {
if param.ParamPass == "out" {
if isPlain {
RustParamOutTypeName := fmt.Sprintf("*mut %s", RustParamTypeName)
return RustParamOutTypeName, nil
} else {
RustParamOutTypeName := fmt.Sprintf("&mut %s", RustParamTypeName)
return RustParamOutTypeName, nil
}
}
}
return RustParamTypeName, nil
}

0 comments on commit fc311f7

Please sign in to comment.