@@ -232,8 +232,10 @@ namespace vast::conv {
232
232
).getResult (0 );
233
233
}
234
234
235
+ using value_range = mlir::ValueRange;
236
+
235
237
std::vector< mlir_value > convert_value_types (
236
- mlir::ValueRange values, mlir::TypeRange types, auto &rewriter
238
+ value_range values, mlir::TypeRange types, auto &rewriter
237
239
) {
238
240
std::vector< mlir_value > out;
239
241
out.reserve (values.size ());
@@ -248,6 +250,33 @@ namespace vast::conv {
248
250
return out;
249
251
}
250
252
253
+ std::vector< mlir_value > realized_operand_values (value_range values) {
254
+ std::vector< mlir_value > out;
255
+ out.reserve (values.size ());
256
+ for (auto val : values) {
257
+ if (auto cast = mlir::dyn_cast< mlir::UnrealizedConversionCastOp >(val.getDefiningOp ())) {
258
+ out.push_back (cast.getOperand (0 ));
259
+ } else {
260
+ out.push_back (val);
261
+ }
262
+ }
263
+ return out;
264
+ }
265
+
266
+ mlir_type join (mlir_type lhs, mlir_type rhs) {
267
+ if (!lhs)
268
+ return rhs;
269
+ return lhs == rhs ? lhs : pr::MaybeDataType::get (lhs.getContext ());
270
+ }
271
+
272
+ mlir_type top_type (value_range values) {
273
+ mlir_type ty;
274
+ for (auto val : values) {
275
+ ty = join (ty, val.getType ());
276
+ }
277
+ return ty;
278
+ }
279
+
251
280
template < typename op_t >
252
281
struct parser_conversion_pattern_base
253
282
: mlir_pattern_mixin< operation_conversion_pattern< op_t > >
@@ -306,6 +335,42 @@ namespace vast::conv {
306
335
}
307
336
};
308
337
338
+ template < typename op_t >
339
+ struct ToMaybeParse : operation_conversion_pattern< op_t >
340
+ {
341
+ using base = operation_conversion_pattern< op_t >;
342
+ using base::base;
343
+
344
+ using adaptor_t = typename op_t ::Adaptor;
345
+
346
+ logical_result matchAndRewrite (
347
+ op_t op, adaptor_t adaptor, conversion_rewriter &rewriter
348
+ ) const override {
349
+ auto args = realized_operand_values (adaptor.getOperands ());
350
+ auto rty = top_type (args);
351
+
352
+ auto converted = [&] () -> operation {
353
+ auto matches_return_type = [rty] (auto val) { return val.getType () == rty; };
354
+ if (mlir::isa< pr::NoDataType >(rty) && llvm::all_of (args, matches_return_type))
355
+ return rewriter.create < pr::NoParse >(op.getLoc (), rty, args);
356
+ return rewriter.create < pr::MaybeParse >(op.getLoc (), rty, args);
357
+ } ();
358
+
359
+ rewriter.replaceOpWithNewOp < mlir::UnrealizedConversionCastOp >(
360
+ op, op.getType (), converted->getResult (0 )
361
+ );
362
+
363
+ return mlir::success ();
364
+ }
365
+
366
+ static void legalize (parser_conversion_config &cfg) {
367
+ cfg.target .addLegalOp < pr::MaybeParse >();
368
+ cfg.target .addLegalOp < pr::NoParse >();
369
+ cfg.target .addLegalOp < mlir::UnrealizedConversionCastOp >();
370
+ cfg.target .addIllegalOp < op_t >();
371
+ }
372
+ };
373
+
309
374
struct CallConversion : parser_conversion_pattern_base< hl::CallOp >
310
375
{
311
376
using op_t = hl::CallOp;
@@ -572,6 +637,8 @@ namespace vast::conv {
572
637
ToNoParse< hl::ImplicitCastOp >,
573
638
ToNoParse< hl::CmpOp >, ToNoParse< hl::FCmpOp >,
574
639
// Integer arithmetic
640
+ ToMaybeParse< hl::AddIOp >, ToMaybeParse< hl::SubIOp >,
641
+ // Non-parsing integer arithmetic operations
575
642
ToNoParse< hl::MulIOp >,
576
643
ToNoParse< hl::DivSOp >, ToNoParse< hl::DivUOp >,
577
644
ToNoParse< hl::RemSOp >, ToNoParse< hl::RemUOp >,
0 commit comments