From 7dc05e8f84b4537cc3e5ccd061095ee99b80aba2 Mon Sep 17 00:00:00 2001 From: Stefan Muenzel Date: Sat, 16 Nov 2024 12:16:08 +0100 Subject: [PATCH 1/3] Add Unsafe base operations --- lib/cstruct.ml | 115 ++++++++++++++++++++++++++++++++---------------- lib/cstruct.mli | 87 +++++++++--------------------------- 2 files changed, 96 insertions(+), 106 deletions(-) diff --git a/lib/cstruct.ml b/lib/cstruct.ml index 5c4b871..78a7780 100644 --- a/lib/cstruct.ml +++ b/lib/cstruct.ml @@ -297,61 +297,98 @@ external swap16 : int -> int = "%bswap16" external swap32 : int32 -> int32 = "%bswap_int32" external swap64 : int64 -> int64 = "%bswap_int64" -let set_uint16 swap p t i c = - if i > t.len - 2 || i < 0 then err_invalid_bounds (p ^ ".set_uint16") t i 2 - else ba_set_int16 t.buffer (t.off+i) (if swap then swap16 c else c) [@@inline] +module type Swap = sig + val swap : bool + val name : string +end -let set_uint32 swap p t i c = - if i > t.len - 4 || i < 0 then err_invalid_bounds (p ^ ".set_uint32") t i 4 - else ba_set_int32 t.buffer (t.off+i) (if swap then swap32 c else c) [@@inline] +module type GetterSetter = sig + val get_uint16: t -> int -> uint16 + val get_uint32: t -> int -> uint32 + val get_uint64: t -> int -> uint64 + val set_uint16: t -> int -> uint16 -> unit + val set_uint32: t -> int -> uint32 -> unit + val set_uint64: t -> int -> uint64 -> unit +end -let set_uint64 swap p t i c = - if i > t.len - 8 || i < 0 then err_invalid_bounds (p ^ ".set_uint64") t i 8 - else ba_set_int64 t.buffer (t.off+i) (if swap then swap64 c else c) [@@inline] +module type GetterSetterWithSwap = sig + module Swap : Swap + include GetterSetter +end -let get_uint16 swap p t i = - if i > t.len - 2 || i < 0 then err_invalid_bounds (p ^ ".get_uint16") t i 2 - else +module Unsafe(Swap : Swap) = struct + module Swap = Swap + + let set_uint16 t i c = + ba_set_int16 t.buffer (t.off+i) (if Swap.swap then swap16 c else c) [@@inline] + + let set_uint32 t i c = + ba_set_int32 t.buffer (t.off+i) (if Swap.swap then swap32 c else c) [@@inline] + + let set_uint64 t i c = + ba_set_int64 t.buffer (t.off+i) (if Swap.swap then swap64 c else c) [@@inline] + + let get_uint16 t i = let r = ba_get_int16 t.buffer (t.off+i) in - if swap then swap16 r else r [@@inline] + if Swap.swap then swap16 r else r [@@inline] -let get_uint32 swap p t i = - if i > t.len - 4 || i < 0 then err_invalid_bounds (p ^ ".get_uint32") t i 4 - else + let get_uint32 t i = let r = ba_get_int32 t.buffer (t.off+i) in - if swap then swap32 r else r [@@inline] + if Swap.swap then swap32 r else r [@@inline] -let get_uint64 swap p t i = - if i > t.len - 8 || i < 0 then err_invalid_bounds (p ^ ".get_uint64") t i 8 - else + let get_uint64 t i = let r = ba_get_int64 t.buffer (t.off+i) in - if swap then swap64 r else r [@@inline] + if Swap.swap then swap64 r else r [@@inline] +end [@@inline] + +module Safe(Unsafe : GetterSetterWithSwap) = struct + module Unsafe = Unsafe + module Swap = Unsafe.Swap + + let set_uint16 t i c = + if i > t.len - 2 || i < 0 then err_invalid_bounds (Swap.name ^ ".set_uint16") t i 2 + else Unsafe.set_uint16 t i c [@@inline] + + let set_uint32 t i c = + if i > t.len - 4 || i < 0 then err_invalid_bounds (Swap.name ^ ".set_uint32") t i 4 + else Unsafe.set_uint32 t i c [@@inline] + + let set_uint64 t i c = + if i > t.len - 8 || i < 0 then err_invalid_bounds (Swap.name ^ ".set_uint64") t i 8 + else Unsafe.set_uint64 t i c [@@inline] + + let get_uint16 t i = + if i > t.len - 2 || i < 0 then err_invalid_bounds (Swap.name ^ ".get_uint16") t i 2 + else Unsafe.get_uint16 t i [@@inline] + + let get_uint32 t i = + if i > t.len - 4 || i < 0 then err_invalid_bounds (Swap.name ^ ".get_uint32") t i 4 + else Unsafe.get_uint32 t i [@@inline] + + let get_uint64 t i = + if i > t.len - 8 || i < 0 then err_invalid_bounds (Swap.name ^ ".get_uint64") t i 8 + else Unsafe.get_uint64 t i [@@inline] +end [@@inline] module BE = struct - let set_uint16 t i c = set_uint16 (not Sys.big_endian) "BE" t i c [@@inline] - let set_uint32 t i c = set_uint32 (not Sys.big_endian) "BE" t i c [@@inline] - let set_uint64 t i c = set_uint64 (not Sys.big_endian) "BE" t i c [@@inline] - let get_uint16 t i = get_uint16 (not Sys.big_endian) "BE" t i [@@inline] - let get_uint32 t i = get_uint32 (not Sys.big_endian) "BE" t i [@@inline] - let get_uint64 t i = get_uint64 (not Sys.big_endian) "BE" t i [@@inline] + include Safe(Unsafe(struct + let swap = not Sys.big_endian + let name = "BE" + end)) end module LE = struct - let set_uint16 t i c = set_uint16 Sys.big_endian "LE" t i c [@@inline] - let set_uint32 t i c = set_uint32 Sys.big_endian "LE" t i c [@@inline] - let set_uint64 t i c = set_uint64 Sys.big_endian "LE" t i c [@@inline] - let get_uint16 t i = get_uint16 Sys.big_endian "LE" t i [@@inline] - let get_uint32 t i = get_uint32 Sys.big_endian "LE" t i [@@inline] - let get_uint64 t i = get_uint64 Sys.big_endian "LE" t i [@@inline] + include Safe(Unsafe(struct + let swap = Sys.big_endian + let name = "LE" + end)) end module HE = struct - let set_uint16 t i c = set_uint16 false "HE" t i c [@@inline] - let set_uint32 t i c = set_uint32 false "HE" t i c [@@inline] - let set_uint64 t i c = set_uint64 false "HE" t i c [@@inline] - let get_uint16 t i = get_uint16 false "HE" t i [@@inline] - let get_uint32 t i = get_uint32 false "HE" t i [@@inline] - let get_uint64 t i = get_uint64 false "HE" t i [@@inline] + include Safe(Unsafe(struct + let swap = false + let name = "HE" + end)) end let length { len ; _ } = len diff --git a/lib/cstruct.mli b/lib/cstruct.mli index fd5b081..336d62a 100644 --- a/lib/cstruct.mli +++ b/lib/cstruct.mli @@ -372,78 +372,57 @@ val to_bytes: ?off:int -> ?len:int -> t -> bytes @raise Invalid_argument if [off] or [len] is negative, or [Cstruct.length str - off] < [len]. *) -module BE : sig - - (** Get/set big-endian integers of various sizes. The second - argument of those functions is the position relative to the - current offset of the cstruct. *) - +module type GetterSetter = sig val get_uint16: t -> int -> uint16 - (** [get_uint16 cstr off] is the 16 bit long big-endian unsigned + (** [get_uint16 cstr off] is the 16 bit long unsigned integer stored in [cstr] at offset [off]. @raise Invalid_argument if the buffer is too small. *) val get_uint32: t -> int -> uint32 - (** [get_uint32 cstr off] is the 32 bit long big-endian unsigned + (** [get_uint32 cstr off] is the 32 bit long unsigned integer stored in [cstr] at offset [off]. @raise Invalid_argument if the buffer is too small. *) val get_uint64: t -> int -> uint64 - (** [get_uint64 cstr off] is the 64 bit long big-endian unsigned + (** [get_uint64 cstr off] is the 64 bit long unsigned integer stored in [cstr] at offset [off]. @raise Invalid_argument if the buffer is too small. *) val set_uint16: t -> int -> uint16 -> unit - (** [set_uint16 cstr off i] writes the 16 bit long big-endian + (** [set_uint16 cstr off i] writes the 16 bit long unsigned integer [i] at offset [off] of [cstr]. @raise Invalid_argument if the buffer is too small. *) val set_uint32: t -> int -> uint32 -> unit - (** [set_uint32 cstr off i] writes the 32 bit long big-endian + (** [set_uint32 cstr off i] writes the 32 bit long unsigned integer [i] at offset [off] of [cstr]. @raise Invalid_argument if the buffer is too small. *) val set_uint64: t -> int -> uint64 -> unit - (** [set_uint64 cstr off i] writes the 64 bit long big-endian + (** [set_uint64 cstr off i] writes the 64 bit long unsigned integer [i] at offset [off] of [cstr]. @raise Invalid_argument if the buffer is too small. *) end -module LE : sig +module BE : sig - (** Get/set little-endian integers of various sizes. The second + (** Get/set big-endian integers of various sizes. The second argument of those functions is the position relative to the current offset of the cstruct. *) + include GetterSetter - val get_uint16: t -> int -> uint16 - (** [get_uint16 cstr off] is the 16 bit long little-endian unsigned - integer stored in [cstr] at offset [off]. - @raise Invalid_argument if the buffer is too small. *) - - val get_uint32: t -> int -> uint32 - (** [get_uint32 cstr off] is the 32 bit long little-endian unsigned - integer stored in [cstr] at offset [off]. - @raise Invalid_argument if the buffer is too small. *) + module Unsafe : GetterSetter +end - val get_uint64: t -> int -> uint64 - (** [get_uint64 cstr off] is the 64 bit long little-endian unsigned - integer stored in [cstr] at offset [off]. - @raise Invalid_argument if the buffer is too small. *) +module LE : sig - val set_uint16: t -> int -> uint16 -> unit - (** [set_uint16 cstr off i] writes the 16 bit long little-endian - unsigned integer [i] at offset [off] of [cstr]. - @raise Invalid_argument if the buffer is too small. *) + (** Get/set little-endian integers of various sizes. The second + argument of those functions is the position relative to the + current offset of the cstruct. *) - val set_uint32: t -> int -> uint32 -> unit - (** [set_uint32 cstr off i] writes the 32 bit long little-endian - unsigned integer [i] at offset [off] of [cstr]. - @raise Invalid_argument if the buffer is too small. *) + include GetterSetter - val set_uint64: t -> int -> uint64 -> unit - (** [set_uint64 cstr off i] writes the 64 bit long little-endian - unsigned integer [i] at offset [off] of [cstr]. - @raise Invalid_argument if the buffer is too small. *) + module Unsafe : GetterSetter end module HE : sig @@ -452,35 +431,9 @@ module HE : sig argument of those functions is the position relative to the current offset of the cstruct. *) - val get_uint16: t -> int -> uint16 - (** [get_uint16 cstr off] is the 16 bit long host-endian unsigned - integer stored in [cstr] at offset [off]. - @raise Invalid_argument if the buffer is too small. *) - - val get_uint32: t -> int -> uint32 - (** [get_uint32 cstr off] is the 32 bit long host-endian unsigned - integer stored in [cstr] at offset [off]. - @raise Invalid_argument if the buffer is too small. *) - - val get_uint64: t -> int -> uint64 - (** [get_uint64 cstr off] is the 64 bit long host-endian unsigned - integer stored in [cstr] at offset [off]. - @raise Invalid_argument if the buffer is too small. *) - - val set_uint16: t -> int -> uint16 -> unit - (** [set_uint16 cstr off i] writes the 16 bit long host-endian - unsigned integer [i] at offset [off] of [cstr]. - @raise Invalid_argument if the buffer is too small. *) + include GetterSetter - val set_uint32: t -> int -> uint32 -> unit - (** [set_uint32 cstr off i] writes the 32 bit long host-endian - unsigned integer [i] at offset [off] of [cstr]. - @raise Invalid_argument if the buffer is too small. *) - - val set_uint64: t -> int -> uint64 -> unit - (** [set_uint64 cstr off i] writes the 64 bit long host-endian - unsigned integer [i] at offset [off] of [cstr]. - @raise Invalid_argument if the buffer is too small. *) + module Unsafe : GetterSetter end (** {2 Debugging } *) From 9581112ca0e9f0cfe7b03d4194a2cdc56ce8b543 Mon Sep 17 00:00:00 2001 From: Stefan Muenzel Date: Sat, 16 Nov 2024 14:54:51 +0100 Subject: [PATCH 2/3] More unsafe ops --- lib/cstruct.ml | 120 +++++++++++++++++++++++++----------------- lib/cstruct.mli | 55 +++++++++++--------- ppx/ppx_cstruct.ml | 126 ++++++++++++++++++++++++++++++++++----------- 3 files changed, 198 insertions(+), 103 deletions(-) diff --git a/lib/cstruct.ml b/lib/cstruct.ml index 78a7780..1d2ac5d 100644 --- a/lib/cstruct.ml +++ b/lib/cstruct.ml @@ -268,23 +268,42 @@ let create len = let t = create_unsafe len in memset t 0; t +module type GetterSetterByte = sig + val get_char: t -> int -> char + val get_uint8: t -> int -> uint8 + val set_char: t -> int -> char -> unit + val set_uint8: t -> int -> uint8 -> unit +end + +module Unsafe = struct + let set_uint8 t i c = + Bigarray.Array1.unsafe_set t.buffer (t.off+i) (Char.unsafe_chr c) + + let set_char t i c = + Bigarray.Array1.unsafe_set t.buffer (t.off+i) c + + let get_uint8 t i = + Char.code (Bigarray.Array1.unsafe_get t.buffer (t.off+i)) + + let get_char t i = + Bigarray.Array1.unsafe_get t.buffer (t.off+i) +end let set_uint8 t i c = if i >= t.len || i < 0 then err_invalid_bounds "set_uint8" t i 1 - else Bigarray.Array1.set t.buffer (t.off+i) (Char.unsafe_chr c) + else Unsafe.set_uint8 t i c let set_char t i c = if i >= t.len || i < 0 then err_invalid_bounds "set_char" t i 1 - else Bigarray.Array1.set t.buffer (t.off+i) c + else Unsafe.set_char t i c let get_uint8 t i = if i >= t.len || i < 0 then err_invalid_bounds "get_uint8" t i 1 - else Char.code (Bigarray.Array1.get t.buffer (t.off+i)) + else Unsafe.get_uint8 t i let get_char t i = if i >= t.len || i < 0 then err_invalid_bounds "get_char" t i 1 - else Bigarray.Array1.get t.buffer (t.off+i) - + else Unsafe.get_char t i external ba_set_int16 : buffer -> int -> uint16 -> unit = "%caml_bigstring_set16u" external ba_set_int32 : buffer -> int -> uint32 -> unit = "%caml_bigstring_set32u" @@ -302,7 +321,7 @@ module type Swap = sig val name : string end -module type GetterSetter = sig +module type GetterSetterMultiByte = sig val get_uint16: t -> int -> uint16 val get_uint32: t -> int -> uint32 val get_uint64: t -> int -> uint64 @@ -311,66 +330,69 @@ module type GetterSetter = sig val set_uint64: t -> int -> uint64 -> unit end -module type GetterSetterWithSwap = sig +module type GetterSetterMultiByteWithSwap = sig module Swap : Swap - include GetterSetter + include GetterSetterMultiByte end -module Unsafe(Swap : Swap) = struct - module Swap = Swap +module Internal = struct + module Unsafe(Swap : Swap) = struct + module Swap = Swap - let set_uint16 t i c = - ba_set_int16 t.buffer (t.off+i) (if Swap.swap then swap16 c else c) [@@inline] + let set_uint16 t i c = + ba_set_int16 t.buffer (t.off+i) (if Swap.swap then swap16 c else c) [@@inline] - let set_uint32 t i c = - ba_set_int32 t.buffer (t.off+i) (if Swap.swap then swap32 c else c) [@@inline] + let set_uint32 t i c = + ba_set_int32 t.buffer (t.off+i) (if Swap.swap then swap32 c else c) [@@inline] - let set_uint64 t i c = - ba_set_int64 t.buffer (t.off+i) (if Swap.swap then swap64 c else c) [@@inline] + let set_uint64 t i c = + ba_set_int64 t.buffer (t.off+i) (if Swap.swap then swap64 c else c) [@@inline] - let get_uint16 t i = - let r = ba_get_int16 t.buffer (t.off+i) in - if Swap.swap then swap16 r else r [@@inline] + let get_uint16 t i = + let r = ba_get_int16 t.buffer (t.off+i) in + if Swap.swap then swap16 r else r [@@inline] - let get_uint32 t i = - let r = ba_get_int32 t.buffer (t.off+i) in - if Swap.swap then swap32 r else r [@@inline] + let get_uint32 t i = + let r = ba_get_int32 t.buffer (t.off+i) in + if Swap.swap then swap32 r else r [@@inline] - let get_uint64 t i = - let r = ba_get_int64 t.buffer (t.off+i) in - if Swap.swap then swap64 r else r [@@inline] -end [@@inline] + let get_uint64 t i = + let r = ba_get_int64 t.buffer (t.off+i) in + if Swap.swap then swap64 r else r [@@inline] + end [@@inline] -module Safe(Unsafe : GetterSetterWithSwap) = struct - module Unsafe = Unsafe - module Swap = Unsafe.Swap + module Safe(Unsafe : GetterSetterMultiByteWithSwap) = struct + module Unsafe = Unsafe + module Swap = Unsafe.Swap - let set_uint16 t i c = - if i > t.len - 2 || i < 0 then err_invalid_bounds (Swap.name ^ ".set_uint16") t i 2 - else Unsafe.set_uint16 t i c [@@inline] + let set_uint16 t i c = + if i > t.len - 2 || i < 0 then err_invalid_bounds (Swap.name ^ ".set_uint16") t i 2 + else Unsafe.set_uint16 t i c [@@inline] - let set_uint32 t i c = - if i > t.len - 4 || i < 0 then err_invalid_bounds (Swap.name ^ ".set_uint32") t i 4 - else Unsafe.set_uint32 t i c [@@inline] + let set_uint32 t i c = + if i > t.len - 4 || i < 0 then err_invalid_bounds (Swap.name ^ ".set_uint32") t i 4 + else Unsafe.set_uint32 t i c [@@inline] - let set_uint64 t i c = - if i > t.len - 8 || i < 0 then err_invalid_bounds (Swap.name ^ ".set_uint64") t i 8 - else Unsafe.set_uint64 t i c [@@inline] + let set_uint64 t i c = + if i > t.len - 8 || i < 0 then err_invalid_bounds (Swap.name ^ ".set_uint64") t i 8 + else Unsafe.set_uint64 t i c [@@inline] - let get_uint16 t i = - if i > t.len - 2 || i < 0 then err_invalid_bounds (Swap.name ^ ".get_uint16") t i 2 - else Unsafe.get_uint16 t i [@@inline] + let get_uint16 t i = + if i > t.len - 2 || i < 0 then err_invalid_bounds (Swap.name ^ ".get_uint16") t i 2 + else Unsafe.get_uint16 t i [@@inline] - let get_uint32 t i = - if i > t.len - 4 || i < 0 then err_invalid_bounds (Swap.name ^ ".get_uint32") t i 4 - else Unsafe.get_uint32 t i [@@inline] + let get_uint32 t i = + if i > t.len - 4 || i < 0 then err_invalid_bounds (Swap.name ^ ".get_uint32") t i 4 + else Unsafe.get_uint32 t i [@@inline] - let get_uint64 t i = - if i > t.len - 8 || i < 0 then err_invalid_bounds (Swap.name ^ ".get_uint64") t i 8 - else Unsafe.get_uint64 t i [@@inline] -end [@@inline] + let get_uint64 t i = + if i > t.len - 8 || i < 0 then err_invalid_bounds (Swap.name ^ ".get_uint64") t i 8 + else Unsafe.get_uint64 t i [@@inline] + end [@@inline] +end module BE = struct + open Internal include Safe(Unsafe(struct let swap = not Sys.big_endian let name = "BE" @@ -378,6 +400,7 @@ module BE = struct end module LE = struct + open Internal include Safe(Unsafe(struct let swap = Sys.big_endian let name = "LE" @@ -385,6 +408,7 @@ module LE = struct end module HE = struct + open Internal include Safe(Unsafe(struct let swap = false let name = "HE" diff --git a/lib/cstruct.mli b/lib/cstruct.mli index 336d62a..ba5cd68 100644 --- a/lib/cstruct.mli +++ b/lib/cstruct.mli @@ -258,25 +258,30 @@ val check_alignment : t -> int -> bool boundary. @raise Invalid_argument if [alignment] is not a positive integer. *) -val get_char: t -> int -> char -(** [get_char t off] returns the character contained in the cstruct - at offset [off]. - @raise Invalid_argument if the offset exceeds cstruct length. *) - -val get_uint8: t -> int -> uint8 -(** [get_uint8 t off] returns the byte contained in the cstruct - at offset [off]. - @raise Invalid_argument if the offset exceeds cstruct length. *) - -val set_char: t -> int -> char -> unit -(** [set_char t off c] sets the byte contained in the cstruct - at offset [off] to character [c]. - @raise Invalid_argument if the offset exceeds cstruct length. *) +module type GetterSetterByte = sig + val get_char: t -> int -> char + (** [get_char t off] returns the character contained in the cstruct + at offset [off]. + @raise Invalid_argument if the offset exceeds cstruct length. *) + + val get_uint8: t -> int -> uint8 + (** [get_uint8 t off] returns the byte contained in the cstruct + at offset [off]. + @raise Invalid_argument if the offset exceeds cstruct length. *) + + val set_char: t -> int -> char -> unit + (** [set_char t off c] sets the byte contained in the cstruct + at offset [off] to character [c]. + @raise Invalid_argument if the offset exceeds cstruct length. *) + + val set_uint8: t -> int -> uint8 -> unit + (** [set_uint8 t off c] sets the byte contained in the cstruct + at offset [off] to byte [c]. + @raise Invalid_argument if the offset exceeds cstruct length. *) +end +include GetterSetterByte -val set_uint8: t -> int -> uint8 -> unit -(** [set_uint8 t off c] sets the byte contained in the cstruct - at offset [off] to byte [c]. - @raise Invalid_argument if the offset exceeds cstruct length. *) +module Unsafe : GetterSetterByte val sub: t -> int -> int -> t (** [sub cstr off len] is [{ t with off = t.off + off; len }] @@ -372,7 +377,7 @@ val to_bytes: ?off:int -> ?len:int -> t -> bytes @raise Invalid_argument if [off] or [len] is negative, or [Cstruct.length str - off] < [len]. *) -module type GetterSetter = sig +module type GetterSetterMultiByte = sig val get_uint16: t -> int -> uint16 (** [get_uint16 cstr off] is the 16 bit long unsigned integer stored in [cstr] at offset [off]. @@ -409,9 +414,9 @@ module BE : sig (** Get/set big-endian integers of various sizes. The second argument of those functions is the position relative to the current offset of the cstruct. *) - include GetterSetter + include GetterSetterMultiByte - module Unsafe : GetterSetter + module Unsafe : GetterSetterMultiByte end module LE : sig @@ -420,9 +425,9 @@ module LE : sig argument of those functions is the position relative to the current offset of the cstruct. *) - include GetterSetter + include GetterSetterMultiByte - module Unsafe : GetterSetter + module Unsafe : GetterSetterMultiByte end module HE : sig @@ -431,9 +436,9 @@ module HE : sig argument of those functions is the position relative to the current offset of the cstruct. *) - include GetterSetter + include GetterSetterMultiByte - module Unsafe : GetterSetter + module Unsafe : GetterSetterMultiByte end (** {2 Debugging } *) diff --git a/ppx/ppx_cstruct.ml b/ppx/ppx_cstruct.ml index 3525db8..f4044d8 100644 --- a/ppx/ppx_cstruct.ml +++ b/ppx/ppx_cstruct.ml @@ -41,6 +41,8 @@ end type mode = Big_endian | Little_endian | Host_endian | Bi_endian +type safety = Safe | Unsafe + type prim = | Char | UInt8 @@ -169,14 +171,26 @@ let create_struct loc endian name fields = let ($.) l x = Longident.Ldot (l, x) let cstruct_id = Longident.Lident "Cstruct" -let mode_mod s = function - |Big_endian -> cstruct_id$."BE"$.s - |Little_endian -> cstruct_id$."LE"$.s - |Host_endian -> cstruct_id$."HE"$.s - |Bi_endian -> cstruct_id$."BL"$.s -let mode_mod loc x s = - Exp.ident ~loc {loc ; txt = mode_mod s x} +let mode_module = function + |Big_endian -> cstruct_id$."BE" + |Little_endian -> cstruct_id$."LE" + |Host_endian -> cstruct_id$."HE" + |Bi_endian -> cstruct_id$."BL" + +let safe_module safety parent = + match safety with + | Safe -> parent + | Unsafe -> parent$."Unsafe" + +let mode_mod s mode safety = + (safe_module safety (mode_module mode))$.s + +let mode_mod loc x safety s = + Exp.ident ~loc {loc ; txt = mode_mod s x safety} + +let safe_mod loc safety s = + Exp.ident ~loc { loc; txt = (safe_module safety cstruct_id)$.s } type op = | Op_get of named_field @@ -203,24 +217,28 @@ let op_name s op = let op_pvar ~loc s op = Ast.pvar ~loc (op_name s op) let op_evar ~loc s op = Ast.evar ~loc (op_name s op) -let get_expr loc s f = - let m = mode_mod loc s.endian in +let get_expr_prim loc s safety prim off = + let m = mode_mod loc s.endian safety in let num x = Ast.eint ~loc x in + [%expr + fun v -> + [%e match prim with + |Char -> [%expr [%e safe_mod loc safety "get_char"] v [%e num off]] + |UInt8 -> [%expr [%e safe_mod loc safety "get_uint8"] v [%e num off]] + |UInt16 -> [%expr [%e m "get_uint16"] v [%e num off]] + |UInt32 -> [%expr [%e m "get_uint32"] v [%e num off]] + |UInt64 -> [%expr [%e m "get_uint64"] v [%e num off]]]] + +let get_expr loc s f = match f.ty with |Buffer (_, _) -> + let num x = Ast.eint ~loc x in let len = width_of_field f in [%expr fun src -> Cstruct.sub src [%e num f.off] [%e num len] ] |Prim prim -> - [%expr - fun v -> - [%e match prim with - |Char -> [%expr Cstruct.get_char v [%e num f.off]] - |UInt8 -> [%expr Cstruct.get_uint8 v [%e num f.off]] - |UInt16 -> [%expr [%e m "get_uint16"] v [%e num f.off]] - |UInt32 -> [%expr [%e m "get_uint32"] v [%e num f.off]] - |UInt64 -> [%expr [%e m "get_uint64"] v [%e num f.off]]]] + get_expr_prim loc s Safe prim f.off let type_of_int_field ~loc = function |Char -> [%type: char] @@ -229,8 +247,18 @@ let type_of_int_field ~loc = function |UInt32 -> [%type: Cstruct.uint32] |UInt64 -> [%type: Cstruct.uint64] +let set_expr_prim loc s safety prim off = + let m = mode_mod loc s.endian safety in + let num x = Ast.eint ~loc x in + [%expr fun v x -> + [%e match prim with + |Char -> [%expr [%e safe_mod loc safety "set_char"] v [%e num off] x] + |UInt8 -> [%expr [%e safe_mod loc safety "set_uint8"] v [%e num off] x] + |UInt16 -> [%expr [%e m "set_uint16"] v [%e num off] x] + |UInt32 -> [%expr [%e m "set_uint32"] v [%e num off] x] + |UInt64 -> [%expr [%e m "set_uint64"] v [%e num off] x]]] + let set_expr loc s f = - let m = mode_mod loc s.endian in let num x = Ast.eint ~loc x in match f.ty with |Buffer (_,_) -> @@ -239,13 +267,7 @@ let set_expr loc s f = fun src srcoff dst -> Cstruct.blit_from_string src srcoff dst [%e num f.off] [%e num len]] |Prim prim -> - [%expr fun v x -> - [%e match prim with - |Char -> [%expr Cstruct.set_char v [%e num f.off] x] - |UInt8 -> [%expr Cstruct.set_uint8 v [%e num f.off] x] - |UInt16 -> [%expr [%e m "set_uint16"] v [%e num f.off] x] - |UInt32 -> [%expr [%e m "set_uint32"] v [%e num f.off] x] - |UInt64 -> [%expr [%e m "set_uint64"] v [%e num f.off] x]]] + set_expr_prim loc s Safe prim f.off let type_of_set ~loc f = match f.ty with @@ -328,13 +350,57 @@ let ops_for s = Op_hexdump; ]) +let make_unsafe_op = function + | Op_sizeof -> None + | Op_blit _ -> None + | Op_copy _ -> None + | Op_hexdump -> None + | Op_hexdump_to_buffer -> None + | Op_get { ty = Prim _ ; _} as op-> Some op + | Op_set { ty = Prim _ ; _ } as op -> Some op + | Op_get _ -> None + | Op_set _ -> None + (** Generate functions of the form {get/set}__ *) let output_struct_one_endian loc s = - List.map - (fun op -> - [%stri let[@ocaml.warning "-32"] [%p op_pvar ~loc s op] = - [%e op_expr loc s op]]) - (ops_for s) + let ops = ops_for s in + let safe_ops = + List.map + (fun op -> + [%stri let[@ocaml.warning "-32"] [%p op_pvar ~loc s op] = + [%e op_expr loc s op]]) + ops + in + let unsafe_module = + let ops = + List.filter_map make_unsafe_op ops + |> List.map + (fun op -> + let expr = + match op with + | Op_get { ty = Prim prim; off; _ } -> + get_expr_prim loc s Unsafe prim off + | Op_set { ty = Prim prim; off; _ } -> + set_expr_prim loc s Unsafe prim off + | _ -> + loc_err loc "Unsupported unsafe op" + in + [%stri let[@ocaml.warning "-32"] [%p op_pvar ~loc s op] = + [%e expr]] + ) + in + let expr = + Mod.structure ops + in + let modname = + "Unsafe_accessors_" ^ s.name + in + {pstr_desc = Pstr_module + {pmb_name = {txt = Some modname; loc}; pmb_expr = expr ; + pmb_attributes = []; pmb_loc = loc;}; pstr_loc = loc;} + in + unsafe_module :: + safe_ops let output_struct _loc s = match s.endian with From 9ea4162429e2cc59d8e0fa272232e06fd145621f Mon Sep 17 00:00:00 2001 From: Stefan Muenzel Date: Sat, 16 Nov 2024 15:07:05 +0100 Subject: [PATCH 3/3] ocaml-migrate-parsetree is not required --- ppx_cstruct.opam | 1 - ppx_test/errors/dune | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ppx_cstruct.opam b/ppx_cstruct.opam index 264ad41..5c622fc 100644 --- a/ppx_cstruct.opam +++ b/ppx_cstruct.opam @@ -26,7 +26,6 @@ depends: [ "cstruct-sexp" {with-test} "cppo" {with-test} "cstruct-unix" {with-test & =version} - "ocaml-migrate-parsetree" {>= "2.1.0" & with-test} "lwt_ppx" {>= "2.0.2" & with-test} ] synopsis: "Access C-like structures directly from OCaml" diff --git a/ppx_test/errors/dune b/ppx_test/errors/dune index ed51ce0..0712dc9 100644 --- a/ppx_test/errors/dune +++ b/ppx_test/errors/dune @@ -4,7 +4,7 @@ (preprocess (action (run %{bin:cppo} -V OCAML:%{ocaml_version} %{input-file}))) - (libraries ppx_cstruct ocaml-migrate-parsetree)) + (libraries ppx_cstruct)) (executable (name gen_tests)