diff --git a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs index 61e71adcff477a..acb629e66e3662 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs @@ -11,6 +11,7 @@ using System.Runtime.InteropServices.Marshalling; using System.Runtime.Versioning; using System.Text; +using System.Threading; namespace System.StubHelpers { @@ -229,6 +230,39 @@ internal static unsafe void ConvertToManaged(StringBuilder sb, IntPtr pNative) internal static class BSTRMarshaler { + private sealed class TrailByte(byte trailByte) + { + public readonly byte Value = trailByte; + } + + // In some early version of VB when there were no arrays developers used to use BSTR as arrays + // The way this was done was by adding a trail byte at the end of the BSTR + // To support this scenario, we need to use a ConditionalWeakTable for this special case and + // save the trail character in here. + // This stores the trail character when a BSTR is used as an array. + private static ConditionalWeakTable? s_trailByteTable; + + private static bool TryGetTrailByte(string strManaged, out byte trailByte) + { + if (s_trailByteTable?.TryGetValue(strManaged, out TrailByte? trailByteObj) == true) + { + trailByte = trailByteObj.Value; + return true; + } + + trailByte = 0; + return false; + } + + private static void SetTrailByte(string strManaged, byte trailByte) + { + if (s_trailByteTable == null) + { + Interlocked.CompareExchange(ref s_trailByteTable, new ConditionalWeakTable(), null); + } + s_trailByteTable!.Add(strManaged, new TrailByte(trailByte)); + } + internal static unsafe IntPtr ConvertToNative(string strManaged, IntPtr pNativeBuffer) { if (null == strManaged) @@ -237,7 +271,7 @@ internal static unsafe IntPtr ConvertToNative(string strManaged, IntPtr pNativeB } else { - bool hasTrailByte = StubHelpers.TryGetStringTrailByte(strManaged, out byte trailByte); + bool hasTrailByte = TryGetTrailByte(strManaged, out byte trailByte); uint lengthInBytes = (uint)strManaged.Length * 2; @@ -320,8 +354,7 @@ internal static unsafe IntPtr ConvertToNative(string strManaged, IntPtr pNativeB if ((length & 1) == 1) { - // odd-sized strings need to have the trailing byte saved in their sync block - StubHelpers.SetStringTrailByte(ret, ((byte*)bstr)[length - 1]); + SetTrailByte(ret, ((byte*)bstr)[length - 1]); } return ret; @@ -1489,19 +1522,6 @@ internal static void CheckStringLength(uint length) } } - // Try to retrieve the extra byte - returns false if not present. - [MethodImpl(MethodImplOptions.InternalCall)] - internal static extern bool TryGetStringTrailByte(string str, out byte data); - - // Set extra byte for odd-sized strings that came from interop as BSTR. - internal static void SetStringTrailByte(string str, byte data) - { - SetStringTrailByte(new StringHandleOnStack(ref str!), data); - } - - [LibraryImport(RuntimeHelpers.QCall, EntryPoint = "StubHelpers_SetStringTrailByte")] - private static partial void SetStringTrailByte(StringHandleOnStack str, byte data); - internal static unsafe void FmtClassUpdateNativeInternal(object obj, byte* pNative, ref CleanupWorkListElement? pCleanupWorkList) { MethodTable* pMT = RuntimeHelpers.GetMethodTable(obj); diff --git a/src/coreclr/vm/ecalllist.h b/src/coreclr/vm/ecalllist.h index 8b120c54eaf91d..a8b2c26a307b30 100644 --- a/src/coreclr/vm/ecalllist.h +++ b/src/coreclr/vm/ecalllist.h @@ -352,7 +352,6 @@ FCFuncEnd() FCFuncStart(gStubHelperFuncs) FCFuncElement("GetDelegateTarget", StubHelpers::GetDelegateTarget) - FCFuncElement("TryGetStringTrailByte", StubHelpers::TryGetStringTrailByte) FCFuncElement("SetLastError", StubHelpers::SetLastError) FCFuncElement("ClearLastError", StubHelpers::ClearLastError) #ifdef FEATURE_COMINTEROP diff --git a/src/coreclr/vm/object.cpp b/src/coreclr/vm/object.cpp index e7f486c48cb4ac..877a565fda44b6 100644 --- a/src/coreclr/vm/object.cpp +++ b/src/coreclr/vm/object.cpp @@ -658,47 +658,6 @@ STRINGREF StringObject::NewString(INT32 length) { } } - -/*==================================NewString=================================== -**Action: Many years ago, VB didn't have the concept of a byte array, so enterprising -** users created one by allocating a BSTR with an odd length and using it to -** store bytes. A generation later, we're still stuck supporting this behavior. -** The way that we do this is to take advantage of the difference between the -** array length and the string length. The string length will always be the -** number of characters between the start of the string and the terminating 0. -** If we need an odd number of bytes, we'll take one wchar after the terminating 0. -** (e.g. at position StringLength+1). The high-order byte of this wchar is -** reserved for flags and the low-order byte is our odd byte. This function is -** used to allocate a string of that shape, but we don't actually mark the -** trailing byte as being in use yet. -**Returns: A newly allocated string. Null if length is less than 0. -**Arguments: length -- the length of the string to allocate -** bHasTrailByte -- whether the string also has a trailing byte. -**Exceptions: OutOfMemoryException if AllocateString fails. -==============================================================================*/ -STRINGREF StringObject::NewString(INT32 length, BOOL bHasTrailByte) { - CONTRACTL { - GC_TRIGGERS; - MODE_COOPERATIVE; - PRECONDITION(length>=0 && length != INT32_MAX); - } CONTRACTL_END; - - STRINGREF pString; - if (length<0 || length == INT32_MAX) { - return NULL; - } else if (length == 0) { - return GetEmptyString(); - } else { - pString = AllocateString(length); - _ASSERTE(pString->GetBuffer()[length]==0); - if (bHasTrailByte) { - _ASSERTE(pString->GetBuffer()[length+1]==0); - } - } - - return pString; -} - //======================================================================== // Creates a System.String object and initializes from // the supplied null-terminated C string. @@ -887,74 +846,6 @@ STRINGREF* StringObject::InitEmptyStringRefPtr() { return EmptyStringRefPtr; } -/*============================InternalTrailByteCheck============================ -**Action: Many years ago, VB didn't have the concept of a byte array, so enterprising -** users created one by allocating a BSTR with an odd length and using it to -** store bytes. A generation later, we're still stuck supporting this behavior. -** The way that we do this is stick the trail byte in the sync block -** whenever we encounter such a situation. Since we expect this to be a very corner case -** accessing the sync block seems like a good enough solution -** -**Returns: True if str contains a VB trail byte, false otherwise. -**Arguments: str -- The string to be examined. -**Exceptions: None -==============================================================================*/ -BOOL StringObject::HasTrailByte() { - WRAPPER_NO_CONTRACT; - - SyncBlock * pSyncBlock = PassiveGetSyncBlock(); - if(pSyncBlock != NULL) - { - return pSyncBlock->HasCOMBstrTrailByte(); - } - - return FALSE; -} - -/*=================================GetTrailByte================================= -**Action: If str contains a vb trail byte, returns a copy of it. -**Returns: True if str contains a trail byte. *bTrailByte is set to -** the byte in question if str does have a trail byte, otherwise -** it's set to 0. -**Arguments: str -- The string being examined. -** bTrailByte -- An out param to hold the value of the trail byte. -**Exceptions: None. -==============================================================================*/ -BOOL StringObject::GetTrailByte(BYTE *bTrailByte) { - CONTRACTL - { - NOTHROW; - GC_NOTRIGGER; - MODE_ANY; - } - CONTRACTL_END; - _ASSERTE(bTrailByte); - *bTrailByte=0; - - BOOL retValue = HasTrailByte(); - - if(retValue) - { - *bTrailByte = GET_VB_TRAIL_BYTE(GetHeader()->PassiveGetSyncBlock()->GetCOMBstrTrailByte()); - } - - return retValue; -} - -/*=================================SetTrailByte================================= -**Action: Sets the trail byte in the sync block -**Returns: True. -**Arguments: str -- The string into which to set the trail byte. -** bTrailByte -- The trail byte to be added to the string. -**Exceptions: None. -==============================================================================*/ -BOOL StringObject::SetTrailByte(BYTE bTrailByte) { - WRAPPER_NO_CONTRACT; - - GetHeader()->GetSyncBlock()->SetCOMBstrTrailByte(MAKE_VB_TRAIL_BYTE(bTrailByte)); - return TRUE; -} - #ifdef USE_CHECKED_OBJECTREFS //------------------------------------------------------------- diff --git a/src/coreclr/vm/object.h b/src/coreclr/vm/object.h index 08bbcd920f782c..66e7b779d99dc8 100644 --- a/src/coreclr/vm/object.h +++ b/src/coreclr/vm/object.h @@ -832,7 +832,6 @@ class StringObject : public Object // characters and the null terminator you should pass in 5 and NOT 6. //======================================================================== static STRINGREF NewString(int length); - static STRINGREF NewString(int length, BOOL bHasTrailByte); static STRINGREF NewString(const WCHAR *pwsz); static STRINGREF NewString(const WCHAR *pwsz, int length); static STRINGREF NewString(LPCUTF8 psz); @@ -843,10 +842,6 @@ class StringObject : public Object static STRINGREF* InitEmptyStringRefPtr(); - BOOL HasTrailByte(); - BOOL GetTrailByte(BYTE *bTrailByte); - BOOL SetTrailByte(BYTE bTrailByte); - /*=================RefInterpretGetStringValuesDangerousForGC====================== **N.B.: This performs no range checking and relies on the caller to have done this. **Args: (IN)ref -- the String to be interpretted. diff --git a/src/coreclr/vm/olevariant.cpp b/src/coreclr/vm/olevariant.cpp index 61ea83f5a56a15..4f7b3b1622a356 100644 --- a/src/coreclr/vm/olevariant.cpp +++ b/src/coreclr/vm/olevariant.cpp @@ -4247,66 +4247,6 @@ TypeHandle OleVariant::GetElementTypeForRecordSafeArray(SAFEARRAY* pSafeArray) } #endif //FEATURE_COMINTEROP -void OleVariant::AllocateEmptyStringForBSTR(BSTR bstr, STRINGREF *pStringObj) -{ - CONTRACTL - { - THROWS; - GC_TRIGGERS; - MODE_COOPERATIVE; - PRECONDITION(CheckPointer(bstr)); - PRECONDITION(CheckPointer(pStringObj)); - } - CONTRACTL_END; - - // The BSTR isn't null so allocate a managed string of the appropriate length. - ULONG length = SysStringByteLen(bstr); - - if (length > MAX_SIZE_FOR_INTEROP) - COMPlusThrow(kMarshalDirectiveException, IDS_EE_STRING_TOOLONG); - - // Check to see if the BSTR has trailing odd byte. - BOOL bHasTrailByte = ((length%sizeof(WCHAR)) != 0); - length = length / sizeof(WCHAR); - SetObjectReference((OBJECTREF*)pStringObj, (OBJECTREF)StringObject::NewString(length, bHasTrailByte)); -} - -void OleVariant::ConvertContentsBSTRToString(BSTR bstr, STRINGREF *pStringObj) -{ - CONTRACTL - { - THROWS; - GC_TRIGGERS; - MODE_COOPERATIVE; - PRECONDITION(CheckPointer(bstr)); - PRECONDITION(CheckPointer(pStringObj)); - } - CONTRACTL_END; - - // this is the right thing to do, but sometimes we - // end up thinking we're marshaling a BSTR when we're not, because - // it's the default type. - ULONG length = SysStringByteLen((BSTR)bstr); - if (length > MAX_SIZE_FOR_INTEROP) - COMPlusThrow(kMarshalDirectiveException, IDS_EE_STRING_TOOLONG); - - ULONG charLength = length/sizeof(WCHAR); - BOOL hasTrailByte = (length%sizeof(WCHAR) != 0); - - memcpyNoGCRefs((*pStringObj)->GetBuffer(), bstr, charLength*sizeof(WCHAR)); - - if (hasTrailByte) - { - BYTE* buff = (BYTE*)bstr; - //set the trail byte - (*pStringObj)->SetTrailByte(buff[length-1]); - } - - // null terminate the StringRef - WCHAR* wstr = (WCHAR *)(*pStringObj)->GetBuffer(); - wstr[charLength] = '\0'; -} - void OleVariant::ConvertBSTRToString(BSTR bstr, STRINGREF *pStringObj) { CONTRACTL @@ -4326,74 +4266,10 @@ void OleVariant::ConvertBSTRToString(BSTR bstr, STRINGREF *pStringObj) if (bstr == NULL) return; - AllocateEmptyStringForBSTR(bstr, pStringObj); - ConvertContentsBSTRToString(bstr, pStringObj); -} - -BSTR OleVariant::AllocateEmptyBSTRForString(STRINGREF *pStringObj) -{ - CONTRACT(BSTR) - { - THROWS; - GC_NOTRIGGER; - MODE_COOPERATIVE; - PRECONDITION(CheckPointer(pStringObj)); - PRECONDITION(*pStringObj != NULL); - POSTCONDITION(RETVAL != NULL); - } - CONTRACT_END; - - ULONG length = (*pStringObj)->GetStringLength(); - if (length > MAX_SIZE_FOR_INTEROP) - COMPlusThrow(kMarshalDirectiveException, IDS_EE_STRING_TOOLONG); - - length = length*sizeof(WCHAR); - if ((*pStringObj)->HasTrailByte()) - { - length += 1; - } - BSTR bstr = SysAllocStringByteLen(NULL, length); - if (bstr == NULL) - ThrowOutOfMemory(); - - RETURN bstr; -} - -void OleVariant::ConvertContentsStringToBSTR(STRINGREF *pStringObj, BSTR bstr) -{ - CONTRACTL - { - THROWS; - GC_NOTRIGGER; - MODE_COOPERATIVE; - PRECONDITION(CheckPointer(pStringObj)); - PRECONDITION(*pStringObj != NULL); - PRECONDITION(CheckPointer(bstr)); - } - CONTRACTL_END; - - DWORD length = (DWORD)(*pStringObj)->GetStringLength(); - if (length > MAX_SIZE_FOR_INTEROP) - COMPlusThrow(kMarshalDirectiveException, IDS_EE_STRING_TOOLONG); - - BYTE *buff = (BYTE*)bstr; - ULONG byteLen = length * sizeof(WCHAR); - - memcpyNoGCRefs(bstr, (*pStringObj)->GetBuffer(), byteLen); - - if ((*pStringObj)->HasTrailByte()) - { - BYTE b; - BOOL hasTrailB; - hasTrailB = (*pStringObj)->GetTrailByte(&b); - _ASSERTE(hasTrailB); - buff[byteLen] = b; - } - else - { - // copy the null terminator - bstr[length] = W('\0'); - } + PREPARE_NONVIRTUAL_CALLSITE(METHOD__BSTRMARSHALER__CONVERT_TO_MANAGED); + DECLARE_ARGHOLDER_ARRAY(args, 1); + args[ARGNUM_0] = PTR_TO_ARGHOLDER(bstr); + CALL_MANAGED_METHOD_RETREF(*pStringObj, STRINGREF, args); } BSTR OleVariant::ConvertStringToBSTR(STRINGREF *pStringObj) @@ -4401,25 +4277,28 @@ BSTR OleVariant::ConvertStringToBSTR(STRINGREF *pStringObj) CONTRACT(BSTR) { THROWS; - GC_NOTRIGGER; + GC_TRIGGERS; MODE_COOPERATIVE; PRECONDITION(CheckPointer(pStringObj)); // A null BSTR should only be returned if the input string is null. POSTCONDITION(RETVAL != NULL || *pStringObj == NULL); -} + } CONTRACT_END; // Initiatilize the return BSTR value to null. BSTR bstr = NULL; - // If the string object isn't null then we convert it to a BSTR. Otherwise we will return null. - if (*pStringObj != NULL) + if (*pStringObj == NULL) { - bstr = AllocateEmptyBSTRForString(pStringObj); - ConvertContentsStringToBSTR(pStringObj, bstr); + RETURN NULL; } + PREPARE_NONVIRTUAL_CALLSITE(METHOD__BSTRMARSHALER__CONVERT_TO_NATIVE); + DECLARE_ARGHOLDER_ARRAY(args, 2); + args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(*pStringObj); + args[ARGNUM_1] = PTR_TO_ARGHOLDER(nullptr); + CALL_MANAGED_METHOD(bstr, BSTR, args); RETURN bstr; } #endif // FEATURE_COMINTEROP diff --git a/src/coreclr/vm/olevariant.h b/src/coreclr/vm/olevariant.h index 7d62924db8e935..f4fc1ed357f647 100644 --- a/src/coreclr/vm/olevariant.h +++ b/src/coreclr/vm/olevariant.h @@ -34,15 +34,9 @@ class OleVariant static void MarshalObjectForOleVariant(const VARIANT *pOle, OBJECTREF * const & pObj); static void MarshalOleRefVariantForObject(OBJECTREF *pObj, VARIANT *pOle); - // Helper functions to convert BSTR to managed strings. - static void AllocateEmptyStringForBSTR(BSTR bstr, STRINGREF *pStringObj); - static void ConvertContentsBSTRToString(BSTR bstr, STRINGREF *pStringObj); static void ConvertBSTRToString(BSTR bstr, STRINGREF *pStringObj); - - // Helper functions to convert managed strings to BSTRs. - static BSTR AllocateEmptyBSTRForString(STRINGREF *pStringObj); - static void ConvertContentsStringToBSTR(STRINGREF *pStringObj, BSTR bstr); static BSTR ConvertStringToBSTR(STRINGREF *pStringObj); + static void MarshalObjectForOleVariantUncommon(const VARIANT *pOle, OBJECTREF * const & pObj); static void MarshalOleVariantForObjectUncommon(OBJECTREF * const & pObj, VARIANT *pOle); #endif // FEATURE_COMINTEROP diff --git a/src/coreclr/vm/qcallentrypoints.cpp b/src/coreclr/vm/qcallentrypoints.cpp index ade4b13e843e3f..d19da2d6b3dfac 100644 --- a/src/coreclr/vm/qcallentrypoints.cpp +++ b/src/coreclr/vm/qcallentrypoints.cpp @@ -505,7 +505,6 @@ static const Entry s_QCall[] = DllImportEntry(X86BaseCpuId) #endif DllImportEntry(StubHelpers_CreateCustomMarshaler) - DllImportEntry(StubHelpers_SetStringTrailByte) DllImportEntry(StubHelpers_ThrowInteropParamException) DllImportEntry(StubHelpers_MarshalToManagedVaList) DllImportEntry(StubHelpers_MarshalToUnmanagedVaList) diff --git a/src/coreclr/vm/stubhelpers.cpp b/src/coreclr/vm/stubhelpers.cpp index 4437f432123d05..4a47ba6e40cd56 100644 --- a/src/coreclr/vm/stubhelpers.cpp +++ b/src/coreclr/vm/stubhelpers.cpp @@ -481,29 +481,6 @@ FCIMPL1(void*, StubHelpers::GetDelegateTarget, DelegateObject *pThisUNSAFE) } FCIMPLEND -#include -FCIMPL2(FC_BOOL_RET, StubHelpers::TryGetStringTrailByte, StringObject* thisRefUNSAFE, UINT8 *pbData) -{ - FCALL_CONTRACT; - - STRINGREF thisRef = ObjectToSTRINGREF(thisRefUNSAFE); - FC_RETURN_BOOL(thisRef->GetTrailByte(pbData)); -} -FCIMPLEND -#include - -extern "C" void QCALLTYPE StubHelpers_SetStringTrailByte(QCall::StringHandleOnStack str, UINT8 bData) -{ - QCALL_CONTRACT; - - BEGIN_QCALL; - - GCX_COOP(); - str.Get()->SetTrailByte(bData); - - END_QCALL; -} - extern "C" void QCALLTYPE StubHelpers_ThrowInteropParamException(INT resID, INT paramIdx) { QCALL_CONTRACT; diff --git a/src/coreclr/vm/stubhelpers.h b/src/coreclr/vm/stubhelpers.h index 00fdd97b6d1b64..fb96c20f74a2ec 100644 --- a/src/coreclr/vm/stubhelpers.h +++ b/src/coreclr/vm/stubhelpers.h @@ -33,8 +33,6 @@ class StubHelpers static FCDECL0(void, ClearLastError ); static FCDECL1(void*, GetDelegateTarget, DelegateObject *pThisUNSAFE); - static FCDECL2(FC_BOOL_RET, TryGetStringTrailByte, StringObject* thisRefUNSAFE, UINT8 *pbData); - static FCDECL2(void, LogPinnedArgument, MethodDesc *localDesc, Object *nativeArg); static FCDECL1(DWORD, CalcVaListSize, VARARGS *varargs); }; @@ -56,7 +54,6 @@ extern "C" IUnknown* QCALLTYPE InterfaceMarshaler_ConvertToNative(QCall::ObjectH extern "C" void QCALLTYPE InterfaceMarshaler_ConvertToManaged(IUnknown** ppUnk, MethodTable* pItfMT, MethodTable* pClsMT, DWORD dwFlags, QCall::ObjectHandleOnStack retObject); #endif -extern "C" void QCALLTYPE StubHelpers_SetStringTrailByte(QCall::StringHandleOnStack str, UINT8 bData); extern "C" void QCALLTYPE StubHelpers_ThrowInteropParamException(INT resID, INT paramIdx); extern "C" void QCALLTYPE StubHelpers_MarshalToManagedVaList(va_list va, VARARGS* pArgIterator); diff --git a/src/coreclr/vm/syncblk.h b/src/coreclr/vm/syncblk.h index 04605ab9ec5cb5..f154c9a8aa22cf 100644 --- a/src/coreclr/vm/syncblk.h +++ b/src/coreclr/vm/syncblk.h @@ -441,13 +441,6 @@ class SyncBlock // can never be 0. ObjectNative::GetHashCode in objectnative.cpp makes sure to enforce this. DWORD m_dwHashCode; - // In some early version of VB when there were no arrays developers used to use BSTR as arrays - // The way this was done was by adding a trail byte at the end of the BSTR - // To support this scenario, we need to use the sync block for this special case and - // save the trail character in here. - // This stores the trail character when a BSTR is used as an array - WCHAR m_BSTRTrailByte; - public: SyncBlock(DWORD indx) : m_Lock((OBJECTHANDLE)NULL) @@ -457,7 +450,6 @@ class SyncBlock , m_pEnCInfo(PTR_NULL) #endif // FEATURE_METADATA_UPDATER , m_dwHashCode(0) - , m_BSTRTrailByte(0) { LIMITED_METHOD_CONTRACT; @@ -597,22 +589,6 @@ class SyncBlock SyncBlockPrecious = 0x80000000, }; - BOOL HasCOMBstrTrailByte() - { - LIMITED_METHOD_CONTRACT; - return (m_BSTRTrailByte!=0); - } - WCHAR GetCOMBstrTrailByte() - { - return m_BSTRTrailByte; - } - void SetCOMBstrTrailByte(WCHAR trailByte) - { - WRAPPER_NO_CONTRACT; - m_BSTRTrailByte = trailByte; - SetPrecious(); - } - private: void InitializeThinLock(DWORD recursionLevel, DWORD threadId); diff --git a/src/tests/Interop/PInvoke/Variant/PInvokeDefs.cs b/src/tests/Interop/PInvoke/Variant/PInvokeDefs.cs index 1f1b027b49f838..4742f0adc0f890 100644 --- a/src/tests/Interop/PInvoke/Variant/PInvokeDefs.cs +++ b/src/tests/Interop/PInvoke/Variant/PInvokeDefs.cs @@ -196,4 +196,11 @@ public struct ObjectWrapper public static extern bool Marshal_Struct_ByRef_Empty(ref ObjectWrapper wrapper); [DllImport(nameof(VariantNative))] public static extern bool Marshal_Struct_ByRef_Null(ref ObjectWrapper wrapper); + + [DllImport(nameof(VariantNative))] + public static extern void GetBSTRWithTrailingByteInVariant(out object variant); + + [DllImport(nameof(VariantNative))] + [return: MarshalAs(UnmanagedType.U1)] + public static extern bool VerifyBSTRWithTrailingByteInVariant(object variant); } diff --git a/src/tests/Interop/PInvoke/Variant/VariantNative.cpp b/src/tests/Interop/PInvoke/Variant/VariantNative.cpp index 95c2ba063390de..38184950175390 100644 --- a/src/tests/Interop/PInvoke/Variant/VariantNative.cpp +++ b/src/tests/Interop/PInvoke/Variant/VariantNative.cpp @@ -1112,3 +1112,45 @@ extern "C" BOOL DLL_EXPORT STDMETHODCALLTYPE Marshal_Struct_ByRef_Null(VariantWr return TRUE; } +#define TRAILING_BYTE 0xFF + +extern "C" void DLL_EXPORT STDMETHODCALLTYPE GetBSTRWithTrailingByteInVariant(VARIANT* pVariant) +{ + VARIANT variant; + VariantInit(&variant); + + BSTR bstr = (BSTR)CoreClrBStrAlloc(sizeof(W("Test")) + 1); // 4 characters + trailing byte + null terminator + bstr[0] = W('T'); + bstr[1] = W('e'); + bstr[2] = W('s'); + bstr[3] = W('t'); + bstr[4] = W('\0'); + *(uint8_t*)(bstr + 5) = TRAILING_BYTE; + + variant.vt = VT_BSTR; + variant.bstrVal = bstr; + + *pVariant = variant; +} + +extern "C" BOOL DLL_EXPORT STDMETHODCALLTYPE VerifyBSTRWithTrailingByteInVariant(VARIANT variant) +{ + if (variant.vt != VT_BSTR) + { + return FALSE; + } + + BSTR bstr = variant.bstrVal; + + // Verify the contents of the BSTR + if (bstr[0] != W('T') || bstr[1] != W('e') || bstr[2] != W('s') || bstr[3] != W('t') || bstr[4] != W('\0')) + { + return FALSE; + } + + if (*(uint8_t*)(bstr + 5) != TRAILING_BYTE) + { + return FALSE; + } + return TRUE; +} diff --git a/src/tests/Interop/PInvoke/Variant/VariantTest.BuiltInCom.cs b/src/tests/Interop/PInvoke/Variant/VariantTest.BuiltInCom.cs index 5e3eea6e0a4df5..0d3a9ef1e79381 100644 --- a/src/tests/Interop/PInvoke/Variant/VariantTest.BuiltInCom.cs +++ b/src/tests/Interop/PInvoke/Variant/VariantTest.BuiltInCom.cs @@ -30,6 +30,7 @@ public static int TestEntryPoint() TestOut(); TestFieldByValue(!builtInComDisabled); TestFieldByRef(!builtInComDisabled); + TestBSTRWithTrailingByte(); } catch (Exception e) { diff --git a/src/tests/Interop/PInvoke/Variant/VariantTest.ComWrappers.cs b/src/tests/Interop/PInvoke/Variant/VariantTest.ComWrappers.cs index 156463443a8881..f83f5a4f263102 100644 --- a/src/tests/Interop/PInvoke/Variant/VariantTest.ComWrappers.cs +++ b/src/tests/Interop/PInvoke/Variant/VariantTest.ComWrappers.cs @@ -27,6 +27,7 @@ public static int TestEntryPoint() TestOut(); TestFieldByValue(testComMarshal); TestFieldByRef(testComMarshal); + TestBSTRWithTrailingByte(); } catch (Exception e) { diff --git a/src/tests/Interop/PInvoke/Variant/VariantTest.cs b/src/tests/Interop/PInvoke/Variant/VariantTest.cs index 9b1b80393d0e63..6837ac005fa822 100644 --- a/src/tests/Interop/PInvoke/Variant/VariantTest.cs +++ b/src/tests/Interop/PInvoke/Variant/VariantTest.cs @@ -378,4 +378,20 @@ private unsafe static void TestFieldByRef(bool hasComSupport) }); } } + + private unsafe static void TestBSTRWithTrailingByte() + { + // Get a VARIANT containing a BSTR from native code that has a trailing byte after the null terminator + GetBSTRWithTrailingByteInVariant(out object variant); + + // The VARIANT should unmarshal as a string + Assert.IsType(variant); + string bstr = (string)variant; + + // Verify the string content is correct + Assert.Equal("Test", bstr); + + // Pass the string back to native code wrapped in a VARIANT to verify the trailing byte is preserved + Assert.True(VerifyBSTRWithTrailingByteInVariant(new BStrWrapper(bstr))); + } }