@@ -139,9 +139,11 @@ inline Type* LegalizedIntVectorType(const Module& M, Type* ty)
139
139
}
140
140
141
141
// Returns true for structs no larger than 'structSize' bits that contain only primitive types
142
- inline bool isLegalStructType (const Module& M, StructType* sTy , unsigned structSize)
142
+ inline bool isLegalStructType (const Module& M, Type* ty , unsigned structSize)
143
143
{
144
+ IGC_ASSERT (ty->isStructTy ());
144
145
const DataLayout& DL = M.getDataLayout ();
146
+ StructType* sTy = dyn_cast<StructType>(ty);
145
147
if (sTy && DL.getStructLayout (sTy )->getSizeInBits () <= structSize)
146
148
{
147
149
for (const auto * EltTy : sTy ->elements ())
@@ -161,7 +163,7 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
161
163
{
162
164
if (ty->isStructTy ())
163
165
{
164
- return isLegalStructType (M, cast<StructType>(ty) , MAX_STRUCT_SIZE_IN_BITS);
166
+ return isLegalStructType (M, ty , MAX_STRUCT_SIZE_IN_BITS);
165
167
}
166
168
else if (ty->isArrayTy ())
167
169
{
@@ -172,14 +174,16 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
172
174
return true ;
173
175
}
174
176
175
- inline bool isPromotableStructType (const Module& M, Type* pointeeType, bool isStackCall)
177
+ // Check if a struct pointer argument is promotable to pass-by-value
178
+ inline bool isPromotableStructType (const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false )
176
179
{
// Review summary of this hunk:
//  - Returns false outright when the EnableByValStructArgPromotion flag is disabled.
//  - The size limit passed to isLegalStructType is MAX_STRUCT_SIZE_IN_BITS for stack calls
//    and MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS otherwise.
//  - The removed version received the pointee type and tested it with isa<StructType>;
//    the added version receives the pointer type and only defers to isLegalStructType
//    when it is a pointer whose element type, as reported by
//    IGCLLVM::getNonOpaquePtrEltTy, is a struct. Anything else yields false.
//  - NOTE(review): the added 'isReturnValue' parameter is not read anywhere in the body
//    visible here -- confirm whether it is intentionally unused.
177
180
if (IGC_IS_FLAG_DISABLED (EnableByValStructArgPromotion))
178
181
return false ;
182
+
179
183
const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
180
- if (isa<StructType>(pointeeType ))
184
+ if (ty-> isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)-> isStructTy ( ))
181
185
{
182
- return isLegalStructType (M, cast<StructType>(pointeeType ), maxSize);
186
+ return isLegalStructType (M, IGCLLVM::getNonOpaquePtrEltTy (ty ), maxSize);
183
187
}
184
188
return false ;
185
189
}
@@ -190,29 +194,18 @@ inline bool FunctionHasPromotableSRetArg(const Module& M, const Function* F)
190
194
if (F->getReturnType ()->isVoidTy () &&
191
195
!F->arg_empty () &&
192
196
F->arg_begin ()->hasStructRetAttr () &&
193
- isPromotableStructType (M, F->arg_begin ()->getParamStructRetType (), F->hasFnAttribute (" visaStackCall" )))
197
+ isPromotableStructType (M, F->arg_begin ()->getType (), F->hasFnAttribute (" visaStackCall" ), true ))
194
198
{
195
199
return true ;
196
200
}
197
201
return false ;
198
202
}
199
203
200
204
// Promotes struct pointer to struct type
201
- inline StructType * PromotedStructValueType (const Module& M, const Argument* arg )
205
+ inline Type * PromotedStructValueType (const Module& M, const Type* ty )
202
206
{
// Review summary of this hunk:
//  - The removed lines took an Argument, returned the struct type recorded on its
//    sret or byval attribute, and otherwise hit an always-failing assert and
//    returned nullptr.
//  - The added lines take the pointer type directly, assert that it is a pointer whose
//    element type is a struct, and return that element type cast to StructType.
//    The declared return type is widened from StructType* to Type*.
//  - NOTE(review): this relies on IGCLLVM::getNonOpaquePtrEltTy yielding the pointee
//    type of 'ty'; its definition is not visible here -- confirm.
203
- if (arg->getType ()->isPointerTy ())
204
- {
205
- if (arg->hasStructRetAttr () && arg->getParamStructRetType ()->isStructTy ())
206
- {
207
- return cast<StructType>(arg->getParamStructRetType ());
208
- }
209
- else if (arg->hasByValAttr () && arg->getParamByValType ()->isStructTy ())
210
- {
211
- return cast<StructType>(arg->getParamByValType ());
212
- }
213
- }
214
- IGC_ASSERT_MESSAGE (0 , " Not implemented case" );
215
- return nullptr ;
207
+ IGC_ASSERT (ty->isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)->isStructTy ());
208
+ return cast<StructType>(IGCLLVM::getNonOpaquePtrEltTy (ty));
216
209
}
217
210
218
211
// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
@@ -225,7 +218,7 @@ inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Value* s
225
218
for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
226
219
{
227
220
Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
228
- Value* elementPtr = builder.CreateInBoundsGEP (strVal-> getType (), strPtr, indices);
221
+ Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
229
222
Value* element = builder.CreateExtractValue (strVal, i);
230
223
builder.CreateStore (element, elementPtr);
231
224
}
@@ -242,7 +235,7 @@ inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type*
242
235
for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
243
236
{
244
237
Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
245
- Value* elementPtr = builder.CreateInBoundsGEP (ty, strPtr, indices);
238
+ Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
246
239
Value* element = builder.CreateLoad (sTy ->getElementType (i), elementPtr);
247
240
strVal = builder.CreateInsertValue (strVal, element, i);
248
241
}
@@ -315,10 +308,10 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
315
308
argTypes.push_back (LegalizedIntVectorType (M, ai->getType ()));
316
309
}
317
310
else if (ai->hasByValAttr () &&
318
- isPromotableStructType (M, ai->getParamByValType (), isStackCall))
311
+ isPromotableStructType (M, ai->getType (), isStackCall))
319
312
{
320
313
fixArgType = true ;
321
- argTypes.push_back (PromotedStructValueType (M, ai));
314
+ argTypes.push_back (PromotedStructValueType (M, ai-> getType () ));
322
315
}
323
316
else if (!isLegalSignatureType (M, ai->getType (), isStackCall))
324
317
{
@@ -336,7 +329,7 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
336
329
// Clone function with new signature
337
330
Type* returnType =
338
331
retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (M.getContext ()) :
339
- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()) :
332
+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()-> getType () ) :
340
333
retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, pFunc->getReturnType ()) :
341
334
pFunc->getReturnType ();
342
335
FunctionType* signature = FunctionType::get (returnType, argTypes, false );
@@ -400,12 +393,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
400
393
if (OldArgIt == pFunc->arg_begin () && retTypeOption == ReturnOpt::RETURN_STRUCT)
401
394
{
402
395
// Create a temp alloca to map the old argument. This will be removed later by SROA.
403
- tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt);
396
+ tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt-> getType () );
404
397
tempAllocaForSRetPointer = builder.CreateAlloca (tempAllocaForSRetPointerTy);
405
398
tempAllocaForSRetPointer = builder.CreateAddrSpaceCast (tempAllocaForSRetPointer, OldArgIt->getType ());
406
399
VMap[&*OldArgIt] = tempAllocaForSRetPointer;
407
400
continue ;
408
401
}
402
+
409
403
NewArgIt->setName (OldArgIt->getName ());
410
404
if (!isLegalIntVectorType (M, OldArgIt->getType ()))
411
405
{
@@ -414,25 +408,24 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
414
408
VMap[&*OldArgIt] = trunc ;
415
409
}
416
410
else if (OldArgIt->hasByValAttr () &&
417
- isPromotableStructType (M, OldArgIt->getParamByValType (), isStackCall))
411
+ isPromotableStructType (M, OldArgIt->getType (), isStackCall))
418
412
{
419
- AllocaInst* newArgPtr = builder.CreateAlloca (OldArgIt->getParamByValType ());
420
413
// remove "byval" attrib since it is now pass-by-value
421
414
NewArgIt->removeAttr (llvm::Attribute::ByVal);
415
+ Value* newArgPtr = builder.CreateAlloca (NewArgIt->getType ());
422
416
StoreToStruct (builder, &*NewArgIt, newArgPtr);
423
417
// cast back to original addrspace
424
418
IGC_ASSERT (OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_GENERIC ||
425
- OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_PRIVATE);
426
- llvm::Value* castedNewArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
427
- VMap[&*OldArgIt] = castedNewArgPtr ;
419
+ OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_PRIVATE);
420
+ newArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
421
+ VMap[&*OldArgIt] = newArgPtr ;
428
422
}
429
423
else if (!isLegalSignatureType (M, OldArgIt->getType (), isStackCall))
430
424
{
431
425
// Load from pointer arg
432
- Value* load = builder.CreateLoad (OldArgIt-> getType (), &*NewArgIt);
426
+ Value* load = builder.CreateLoad (&*NewArgIt);
433
427
VMap[&*OldArgIt] = load;
434
- llvm::Attribute byValAttr = llvm::Attribute::getWithByValType (M.getContext (), OldArgIt->getType ());
435
- NewArgIt->addAttr (byValAttr);
428
+ ArgByVal.push_back (&*NewArgIt);
436
429
}
437
430
else
438
431
{
@@ -451,13 +444,21 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
451
444
builder.CreateBr (ClonedEntryBB);
452
445
MergeBlockIntoPredecessor (ClonedEntryBB);
453
446
447
+ // Loop through new args and add 'byval' attributes
448
+ for (auto arg : ArgByVal)
449
+ {
450
+ arg->addAttr (llvm::Attribute::getWithByValType (M.getContext (),
451
+ IGCLLVM::getNonOpaquePtrEltTy (arg->getType ())));
452
+ }
453
+
454
454
// Now fix the return values
455
455
if (retTypeOption == ReturnOpt::RETURN_BY_REF)
456
456
{
457
457
// Add the 'noalias' and 'sret' attributes to arg0
458
458
auto retArg = pNewFunc->arg_begin ();
459
459
retArg->addAttr (llvm::Attribute::NoAlias);
460
- retArg->addAttr (llvm::Attribute::getWithStructRetType (M.getContext (), pFunc->getReturnType ()));
460
+ retArg->addAttr (llvm::Attribute::getWithStructRetType (
461
+ M.getContext (), IGCLLVM::getNonOpaquePtrEltTy (retArg->getType ())));
461
462
462
463
// Loop through all return instructions and store the old return value into the arg0 pointer
463
464
const auto ptrSize = DL.getPointerSize ();
@@ -576,7 +577,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
576
577
if (callInst->getType ()->isVoidTy () &&
577
578
IGCLLVM::getNumArgOperands (callInst) > 0 &&
578
579
callInst->paramHasAttr (0 , llvm::Attribute::StructRet) &&
579
- isPromotableStructType (M, callInst->getParamAttr ( 0 , llvm::Attribute::StructRet). getValueAsType (), isStackCall))
580
+ isPromotableStructType (M, callInst->getArgOperand ( 0 )-> getType (), isStackCall, true /* retval */ ))
580
581
{
581
582
opNum++; // Skip the first call operand
582
583
retTypeOption = ReturnOpt::RETURN_STRUCT;
@@ -607,17 +608,18 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
607
608
{
608
609
// extend the illegal int to a legal type
609
610
IGCLLVM::IRBuilder<> builder (callInst);
610
- Value* extend = builder.CreateZExt (callInst-> getOperand (opNum) , LegalizedIntVectorType (M, arg->getType ()));
611
+ Value* extend = builder.CreateZExt (arg , LegalizedIntVectorType (M, arg->getType ()));
611
612
callArgs.push_back (extend );
612
613
ArgAttrVec.push_back (AttributeSet ());
613
614
fixArgType = true ;
614
615
}
615
616
else if (callInst->paramHasAttr (opNum, llvm::Attribute::ByVal) &&
616
- isPromotableStructType (M, callInst-> getParamByValType (opNum ), isStackCall))
617
+ isPromotableStructType (M, arg-> getType ( ), isStackCall))
617
618
{
618
619
// Map the new operand to the loaded value of the struct pointer
619
620
IGCLLVM::IRBuilder<> builder (callInst);
620
- Value* newOp = LoadFromStruct (builder, callInst->getOperand (opNum), callInst->getParamByValType (opNum));
621
+ Argument* callArg = IGCLLVM::getArg (*calledFunc, opNum);
622
+ Value* newOp = LoadFromStruct (builder, arg, callArg->getParamByValType ());
621
623
callArgs.push_back (newOp);
622
624
ArgAttrVec.push_back (AttributeSet ());
623
625
fixArgType = true ;
@@ -627,7 +629,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
627
629
// Create and store operand as an alloca, then pass as argument
628
630
IGCLLVM::IRBuilder<> builder (callInst);
629
631
Value* allocaV = builder.CreateAlloca (arg->getType ());
630
- builder.CreateStore (callInst-> getOperand (opNum) , allocaV);
632
+ builder.CreateStore (arg , allocaV);
631
633
callArgs.push_back (allocaV);
632
634
auto byValAttr = llvm::Attribute::getWithByValType (M.getContext (), arg->getType ());
633
635
auto argAttrs = AttributeSet::get (M.getContext (), { byValAttr });
@@ -657,7 +659,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
657
659
}
658
660
Type* retType =
659
661
retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (callInst->getContext ()) :
660
- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getFunction ( )->getArg ( 0 )) :
662
+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getArgOperand ( 0 )->getType ( )) :
661
663
retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, callInst->getType ()) :
662
664
callInst->getType ();
663
665
newFnTy = FunctionType::get (retType, argTypes, false );
0 commit comments