From 43050e6d773861320d35aa4bb47eeaa4f7b5ee12 Mon Sep 17 00:00:00 2001
From: cannorin <13620400+cannorin@users.noreply.github.com>
Date: Sun, 22 Sep 2024 17:23:07 +0900
Subject: [PATCH] Monad instance for `Vector` and `Matrix` (#607)
---
src/FSharpPlus.TypeLevel/Data/Matrix.fs | 33 ++++++-
.../FSharpPlus.Tests/FSharpPlus.Tests.fsproj | 1 +
tests/FSharpPlus.Tests/Matrix.fs | 98 +++++++++++++++++++
tests/FSharpPlus.Tests/TypeLevel.fs | 52 +---------
4 files changed, 131 insertions(+), 53 deletions(-)
create mode 100644 tests/FSharpPlus.Tests/Matrix.fs
diff --git a/src/FSharpPlus.TypeLevel/Data/Matrix.fs b/src/FSharpPlus.TypeLevel/Data/Matrix.fs
index b870b2c61..d264101a3 100644
--- a/src/FSharpPlus.TypeLevel/Data/Matrix.fs
+++ b/src/FSharpPlus.TypeLevel/Data/Matrix.fs
@@ -251,6 +251,17 @@ module Vector =
let inline apply (f: Vector<'a -> 'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> = map2 id f v
+ ///
+ /// Converts the vector of vectors to a square matrix and returns its diagonal.
+ ///
+ ///
+ []
+ let join (vv: Vector, 'n>): Vector<'a, 'n> =
+ { Items = Array.init (Array.length vv.Items) (fun i -> vv.Items.[i].Items.[i]) }
+
+ let inline bind (f: 'a -> Vector<'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> =
+ v |> map f |> join
+
let inline norm (v: Vector< ^a, ^n >) : ^a =
v |> toArray |> Array.sumBy (fun x -> x * x) |> sqrt
let inline maximumNorm (v: Vector< ^a, ^n >) : ^a =
@@ -327,6 +338,20 @@ module Matrix =
for j = 0 to Array2D.length2 m1.Items - 1 do
f i j m1.Items.[i, j] m2.Items.[i, j]
+ let inline apply (f: Matrix<'a -> 'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = map2 id f m
+
+ ///
+ /// Converts the matrix of matrices to a 3D cube matrix and returns its diagonal.
+ ///
+ ///
+ []
+ let join (m: Matrix, 'm, 'n>) : Matrix<'a, 'm, 'n> =
+ { Items =
+ Array2D.init (Array2D.length1 m.Items) (Array2D.length2 m.Items)
+ (fun i j -> m.Items.[i, j].Items.[i, j] ) }
+
+ let inline bind (f: 'a -> Matrix<'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = m |> map f |> join
+
let inline rowLength (_: Matrix<'a, 'm, 'n>) : 'm = Singleton<'m>
let inline colLength (_: Matrix<'a, 'm, 'n>) : 'n = Singleton<'n>
let inline rowLength' (_: Matrix<'a, ^m, 'n>) : int = RuntimeValue (Singleton< ^m >)
@@ -571,8 +596,10 @@ type Matrix<'Item, 'Row, 'Column> with
static member inline Return (x: 'x) : Matrix<'x, 'm, 'n> = Matrix.replicate Singleton Singleton x
static member inline Pure (x: 'x) : Matrix<'x, 'm, 'n> = Matrix.replicate Singleton Singleton x
- static member inline ( <*> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.map2 id f x
- static member inline ( <.> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.map2 id f x
+ static member inline ( <*> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
+ static member inline ( <.> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
+ static member inline Join (x: Matrix, 'm, 'n>) = Matrix.join x
+ static member inline ( >>= ) (x: Matrix<'x, 'm, 'n>, f: 'x -> Matrix<'y, 'm, 'n>) = Matrix.bind f x
static member inline get_Zero () : Matrix<'a, 'm, 'n> = Matrix.zero
static member inline ( + ) (m1, m2) = Matrix.map2 (+) m1 m2
static member inline ( - ) (m1, m2) = Matrix.map2 (-) m1 m2
@@ -607,6 +634,8 @@ type Vector<'Item, 'Length> with
static member inline Pure (x: 'x) : Vector<'x, 'n> = Vector.replicate Singleton x
static member inline ( <*> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
static member inline ( <.> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
+ static member inline Join (x: Vector, 'n>) : Vector<'x, 'n> = Vector.join x
+ static member inline ( >>= ) (x: Vector<'x, 'n>, f: 'x -> Vector<'y, 'n>) = Vector.bind f x
[]
static member inline Zip (x, y) = Vector.zip x y
diff --git a/tests/FSharpPlus.Tests/FSharpPlus.Tests.fsproj b/tests/FSharpPlus.Tests/FSharpPlus.Tests.fsproj
index 4a8e29469..734bf6e45 100644
--- a/tests/FSharpPlus.Tests/FSharpPlus.Tests.fsproj
+++ b/tests/FSharpPlus.Tests/FSharpPlus.Tests.fsproj
@@ -35,6 +35,7 @@
+
diff --git a/tests/FSharpPlus.Tests/Matrix.fs b/tests/FSharpPlus.Tests/Matrix.fs
new file mode 100644
index 000000000..3e0d59b01
--- /dev/null
+++ b/tests/FSharpPlus.Tests/Matrix.fs
@@ -0,0 +1,98 @@
+namespace FSharpPlus.Tests
+
+open System
+open NUnit.Framework
+open Helpers
+
+open FSharpPlus
+open FSharpPlus.Data
+open FSharpPlus.TypeLevel
+
+module VectorTests =
+ []
+ let constructorAndDeconstructorWorks() =
+ let v1 = vector (1,2,3,4,5)
+ let v2 = vector (1,2,3,4,5,6,7,8,9,0,1,2,3,4,5)
+ let (Vector(_,_,_,_,_)) = v1
+ let (Vector(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)) = v2
+ ()
+
+ []
+ let applicativeWorks() =
+ let v = vector ((fun i -> i + 1), (fun i -> i * 2))
+ let u = vector (2, 3)
+ let vu = v <*> u
+ NUnit.Framework.Assert.IsInstanceOf