Skip to content

Commit

Permalink
Use enums for basic and preset types in StarknetTypedData (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
DelevoXDG authored Apr 8, 2024
1 parent 0600ddd commit 00b7939
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 53 deletions.
25 changes: 25 additions & 0 deletions Sources/Starknet/Data/TypedData/BasicType.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import Foundation

extension StarknetTypedData {
enum BasicType: String, CaseIterable {
case felt
case bool
case selector
case string
case u128
case i128
case `enum`
case merkletree
case contractAddress = "ContractAddress"
case classHash = "ClassHash"
case timestamp
case shortstring

static func cases(revision: Revision) -> [BasicType] {
switch revision {
case .v0: [.felt, .bool, .selector, .string, .merkletree]
case .v1: allCases
}
}
}
}
33 changes: 33 additions & 0 deletions Sources/Starknet/Data/TypedData/PresetType.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import Foundation

public extension StarknetTypedData {
enum PresetType: String, Equatable, CaseIterable {
case u256
case tokenAmount = "TokenAmount"
case nftId = "NftId"

public var params: [TypeDeclarationWrapper] {
switch self {
case .u256:
[TypeDeclarationWrapper.standard(.init(name: "low", type: BasicType.u128.rawValue)), TypeDeclarationWrapper.standard(.init(name: "high", type: BasicType.u128.rawValue))]
case .tokenAmount:
[TypeDeclarationWrapper.standard(.init(name: "token_address", type: BasicType.contractAddress.rawValue)), TypeDeclarationWrapper.standard(.init(name: "amount", type: Self.u256.rawValue))]
case .nftId:
[TypeDeclarationWrapper.standard(.init(name: "collection_address", type: BasicType.contractAddress.rawValue)), TypeDeclarationWrapper.standard(.init(name: "token_id", type: Self.u256.rawValue))]
}
}

fileprivate enum CodingKeys: CodingKey {
case u256
case tokenAmount
case nftId
}

static func cases(revision: Revision) -> [PresetType] {
switch revision {
case .v0: []
case .v1: allCases
}
}
}
}
72 changes: 19 additions & 53 deletions Sources/Starknet/Data/TypedData/StarknetTypedData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
self.revision = try domain.resolveRevision()

self.allTypes = self.types.merging(
Self.getPresetTypes(revision: self.revision),
Self.PresetType.cases(revision: self.revision).reduce(into: [:]) { $0[$1.rawValue] = $1.params },
uniquingKeysWith: { current, _ in current }
)

Expand Down Expand Up @@ -141,8 +141,8 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
throw StarknetTypedDataError.dependencyNotDefined(domain.separatorName)
}

let basicTypes = Self.getBasicTypes(revision: revision)
let presetTypes = Self.getPresetTypes(revision: revision)
let basicTypes = Self.BasicType.cases(revision: revision).map(\.rawValue)
let presetTypes = Self.PresetType.cases(revision: revision).reduce(into: [:]) { $0[$1.rawValue] = $1.params }

let referencedTypes = try Set(types.values.flatMap { type in
try type.flatMap { param in
Expand Down Expand Up @@ -189,7 +189,7 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
switch param {
case let .enum(enumType):
guard revision == .v1 else {
throw StarknetTypedDataError.unsupportedType("enum")
throw StarknetTypedDataError.unsupportedType(BasicType.enum.rawValue)
}
return [enumType.contains]
default:
Expand Down Expand Up @@ -240,7 +240,7 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
switch param {
case let .enum(enumType):
guard revision == .v1 else {
throw StarknetTypedDataError.unsupportedType("enum")
throw StarknetTypedDataError.unsupportedType(BasicType.enum.rawValue)
}
return enumType.contains
default:
Expand All @@ -249,7 +249,7 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
}
func encodeEnumTypes(from type: String) throws -> String {
guard revision == .v1 else {
throw StarknetTypedDataError.unsupportedType("enum")
throw StarknetTypedDataError.unsupportedType(BasicType.enum.rawValue)
}

let enumTypes = try type.extractEnumTypes().map(escape).joined(separator: ",")
Expand Down Expand Up @@ -298,30 +298,33 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
return hashArray(hashes)
}

switch (typeName, revision) {
case ("felt", _), ("string", .v0), ("shortstring", .v1), ("ContractAddress", .v1), ("ClassHash", .v1):
let basicType = BasicType(rawValue: typeName)
switch (basicType, revision) {
case (.felt, _), (.string, .v0), (.shortstring, .v1), (.contractAddress, .v1), (.classHash, .v1):
return try unwrapFelt(from: element)
case ("u128", .v1), ("timestamp", .v1):
case (.u128, .v1), (.timestamp, .v1):
return try unwrapU128(from: element)
case ("i128", .v1):
case (.i128, .v1):
return try unwrapI128(from: element)
case ("bool", _):
case (.bool, _):
return try unwrapBool(from: element)
case ("string", .v1):
case (.string, .v1):
return try hashArray(unwrapLongString(from: element))
case ("selector", _):
case (.selector, _):
return try unwrapSelector(from: element)
case ("enum", .v1):
case (.enum, .v1):
guard let context else {
throw StarknetTypedDataError.contextNotDefined
}
return try unwrapEnum(from: element, context: context)
case ("merkletree", _):
case (.merkletree, _):
guard let context else {
throw StarknetTypedDataError.contextNotDefined
}
return try prepareMerkleTreeRoot(from: element, context: context)
default:
case (.some, .v0):
throw StarknetTypedDataError.unsupportedType(typeName)
case (nil, _):
throw StarknetTypedDataError.dependencyNotDefined(typeName)
}
}
Expand Down Expand Up @@ -486,43 +489,6 @@ public extension StarknetTypedData {
}
}

private extension StarknetTypedData {
static let basicTypesV0: Set = ["felt", "bool", "string", "selector", "merkletree"]
static let basicTypesV1: Set = basicTypesV0.union(["enum", "u128", "i128", "ContractAddress", "ClassHash", "timestamp", "shortstring"])

static let presetTypesV1 = [
"u256": [
StandardType(name: "low", type: "u128"),
StandardType(name: "high", type: "u128"),
],
"TokenAmount": [
StandardType(name: "token_address", type: "ContractAddress"),
StandardType(name: "amount", type: "u256"),
],
"NftId": [
StandardType(name: "collection_address", type: "ContractAddress"),
StandardType(name: "token_id", type: "u256"),
],
]

static func getBasicTypes(revision: Revision) -> Set<String> {
switch revision {
case .v0:
basicTypesV0
case .v1:
basicTypesV1
}
}

static func getPresetTypes(revision: Revision) -> [String: [TypeDeclarationWrapper]] {
let types: [String: [any TypeDeclaration]] = switch revision {
case .v0: [:]
case .v1: Self.presetTypesV1
}
return types.mapValues { $0.map { TypeDeclarationWrapper($0) } }
}
}

extension StarknetTypedData {
struct Context: Equatable {
let parent: String
Expand Down

0 comments on commit 00b7939

Please sign in to comment.