Skip to content

Commit

Permalink
feat: multiple inherited constructors in abstract contracts (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
agusduha authored Apr 17, 2024
1 parent b434a9e commit bf1a15f
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 16 deletions.
27 changes: 27 additions & 0 deletions solidity/contracts/utils/ContractI.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
pragma solidity ^0.8.0;

contract Ownable {
address public owner;

constructor(address _owner) {
owner = _owner;
}
}

contract Pausable {
bool public paused;

constructor(bool _paused) {
paused = _paused;
}
}

abstract contract ContractI is Ownable, Pausable {}

abstract contract ContractI2 is Ownable, Pausable {
bool public boolean;

constructor(bool _boolean) {
boolean = _boolean;
}
}
10 changes: 5 additions & 5 deletions src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ import {
extractOverrides,
hasNestedMappings,
extractStructFieldsNames,
extractConstructorsParameters,
} from './utils';
import { ContractDefinition, FunctionDefinition, VariableDeclaration, Identifier, ImportDirective } from 'solc-typed-ast';
import { FunctionDefinition, VariableDeclaration, Identifier, ImportDirective } from 'solc-typed-ast';

export function internalFunctionContext(node: FunctionDefinition): InternalFunctionContext {
// Check if the function is internal
Expand Down Expand Up @@ -107,13 +108,12 @@ export function externalOrPublicFunctionContext(node: FunctionDefinition): Exter
export function constructorContext(node: FunctionDefinition): ConstructorContext {
if (!node.isConstructor) throw new Error('The node is not a constructor');

// Get the parameters of the constructor, if there are no parameters then we use an empty array
const { functionParameters: parameters, parameterNames } = extractParameters(node.vParameters.vParameters);
// Get the parameters of the constructors, if there are no parameters then we use an empty array
const { parameters, contracts } = extractConstructorsParameters(node as FullFunctionDefinition);

return {
parameters: parameters.join(', '),
parameterNames: parameterNames.join(', '),
contractName: (node.vScope as ContractDefinition).name,
contracts: Array.from(contracts).join(' '),
};
}

Expand Down
2 changes: 1 addition & 1 deletion src/templates/partials/constructor.hbs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
constructor({{parameters}}) {{contractName}}({{parameterNames}}) {}
constructor({{parameters}}) {{contracts}} {}
4 changes: 2 additions & 2 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ export const explicitTypes = ['string', 'bytes', 'mapping', 'struct'];
// Contexts to pass to Handlebars templates
export interface ConstructorContext {
parameters: string;
parameterNames: string;
contractName: string;
contracts: string;
}

export interface ExternalFunctionContext {
Expand Down Expand Up @@ -89,6 +88,7 @@ interface Selector {
implemented: boolean;
contracts?: Set<string>;
function?: FunctionDefinition;
constructors?: FunctionDefinition[];
}

export interface SelectorsMap {
Expand Down
54 changes: 54 additions & 0 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ export function smockableNode(node: ASTNode): boolean {
if (node.constant || node.mutability === 'immutable') return false;
// If the state variable is private we don't mock it
if (node.visibility === 'private') return false;
} else if (node instanceof FunctionDefinition) {
if (node.isConstructor && (node.parent as ContractDefinition)?.abstract) return false;
} else if (!(node instanceof FunctionDefinition)) {
// Only process variables and functions
return false;
Expand All @@ -314,6 +316,17 @@ export async function renderAbstractUnimplementedFunctions(contract: ContractDef
const currentSelectors = [...contract.vStateVariables, ...contract.vFunctions].map((node) => node.raw?.functionSelector);
const inheritedSelectors = getAllInheritedSelectors(contract);

// If the abstract contract has a constructor, we need to add it to the selectors
if (contract.vConstructor) {
const constructors = inheritedSelectors['constructor'];
inheritedSelectors['constructor'] = {
implemented: constructors.implemented,
function: contract.vConstructor,
contracts: constructors.contracts ? constructors.contracts.add(contract.name) : new Set([contract.name]),
constructors: constructors.constructors ? [...constructors.constructors, contract.vConstructor] : [contract.vConstructor],
};
}

for (const selector in inheritedSelectors) {
// Skip the functions that are already implemented in the current contract
if (currentSelectors.includes(selector)) continue;
Expand Down Expand Up @@ -359,11 +372,13 @@ export const getAllInheritedSelectors = (contract: ContractDefinition, selectors

const contracts = selectors[selector]?.contracts;
const isImplemented = selectors[selector]?.implemented;
const constructors = selectors[selector]?.constructors || [];

selectors[selector] = {
implemented: isImplemented || (!func.isConstructor && func.implemented),
contracts: contracts ? contracts.add(base.name) : new Set([base.name]),
function: func,
constructors: func.isConstructor ? constructors.concat(func) : constructors,
};
}

Expand Down Expand Up @@ -462,3 +477,42 @@ export const extractStructFieldsNames = (node: TypeName): string[] | null => {

return fields.map((field) => (field as VariableDeclaration).name).filter((name) => name);
};

/**
* Extracts the parameters of the constructors of a contract
* @param node The function to extract the constructors parameters from
* @returns The parameters and contracts of the constructors
*/
export const extractConstructorsParameters = (
node: FullFunctionDefinition,
): {
parameters: string[];
contracts: string[];
} => {
let constructors: FunctionDefinition[];

if (node?.selectors?.['constructor']?.constructors?.length > 1) {
constructors = node.selectors['constructor'].constructors;
} else {
constructors = [node];
}

const allParameters: string[] = [];
const allContracts: string[] = [];

for (const func of constructors) {
const { functionParameters: parameters, parameterNames } = extractParameters(func.vParameters.vParameters);
const contractName = (func.vScope as ContractDefinition).name;
const contractValue = `${contractName}(${parameterNames.join(', ')})`;

if (allContracts.includes(contractValue)) continue;

allParameters.push(...parameters);
allContracts.push(contractValue);
}

return {
parameters: allParameters,
contracts: allContracts,
};
};
50 changes: 42 additions & 8 deletions test/unit/context/constructorContext.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { expect } from 'chai';
import { mockContractDefinition, mockFunctionDefinition, mockParameterList, mockVariableDeclaration } from '../../mocks';
import { constructorContext } from '../../../src/context';
import { DataLocation } from 'solc-typed-ast';
import { FullFunctionDefinition, SelectorsMap } from '../../../src/types';

describe('constructorContext', () => {
const defaultAttributes = {
Expand All @@ -21,9 +22,8 @@ describe('constructorContext', () => {
const context = constructorContext(node);

expect(context).to.eql({
contractName: 'TestContract',
parameters: '',
parameterNames: '',
contracts: 'TestContract()',
});
});

Expand All @@ -36,9 +36,8 @@ describe('constructorContext', () => {
const context = constructorContext(node);

expect(context).to.eql({
contractName: 'TestContract',
parameters: 'uint256 a, boolean b',
parameterNames: 'a, b',
contracts: 'TestContract(a, b)',
});
});

Expand All @@ -48,9 +47,8 @@ describe('constructorContext', () => {
const context = constructorContext(node);

expect(context).to.eql({
contractName: 'TestContract',
parameters: 'uint256 _param0, boolean _param1',
parameterNames: '_param0, _param1',
contracts: 'TestContract(_param0, _param1)',
});
});

Expand All @@ -63,9 +61,45 @@ describe('constructorContext', () => {
const context = constructorContext(node);

expect(context).to.eql({
contractName: 'TestContract',
parameters: 'uint256 memory a, boolean calldata b',
parameterNames: 'a, b',
contracts: 'TestContract(a, b)',
});
});

it('processes inherited constructors', () => {
const parameters = [
mockVariableDeclaration({ name: 'a', typeString: 'uint256', storageLocation: DataLocation.Memory }),
mockVariableDeclaration({ name: 'b', typeString: 'boolean', storageLocation: DataLocation.CallData }),
];

const nodeA = mockFunctionDefinition({
...defaultAttributes,
vParameters: mockParameterList({ vParameters: parameters }),
vScope: mockContractDefinition({ name: 'TestContractA' }),
});

const nodeB = mockFunctionDefinition({
...defaultAttributes,
vParameters: mockParameterList({ vParameters: parameters }),
vScope: mockContractDefinition({ name: 'TestContractB' }),
});

const selectors: SelectorsMap = {
constructor: {
implemented: false,
contracts: new Set(['TestContractA', 'TestContractB']),
constructors: [nodeA, nodeB],
},
};

const nodeWithSelectors = nodeA as FullFunctionDefinition;
nodeWithSelectors.selectors = selectors;

const context = constructorContext(nodeWithSelectors);

expect(context).to.eql({
parameters: 'uint256 memory a, boolean calldata b, uint256 memory a, boolean calldata b',
contracts: 'TestContractA(a, b) TestContractB(a, b)',
});
});
});

0 comments on commit bf1a15f

Please sign in to comment.