Skip to content

Commit

Permalink
update some templates
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Nov 15, 2024
1 parent 57e70af commit 941a21c
Showing 1 changed file with 73 additions and 69 deletions.
142 changes: 73 additions & 69 deletions src/physics/tensorwrap.nim
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,6 @@ forwardTT(numberType)
forwardTTW(toSingle)
forwardTTW(toDouble)

template row*(x: SomeTensor, i: auto): auto =
mixin row
tensorObj(x.kind, row(x[],i))
template setRow*[K,T1,T2](r: SomeTensor[K,T1]; x: SomeTensor2[K,T2]; i: auto): auto =
setRow(r[], x[], i)

#** Unary assignment functions
template setUnaryAssignT(f) {.dirty.} =
template f*[R:SomeTensor,X:SomeTensor2](r: var R, x: X) =
Expand Down Expand Up @@ -172,10 +166,10 @@ setUnaryAssignX(imul, AsComplex)

#** Unary functions
template setUnopT(f, g) {.dirty.} =
template f*(x: typedesc[SomeTensor]): typedesc =
tensorObj(x.kind, f(x[]))
template f*(x: SomeTensor): auto =
tensorObj(x.kind, f(x[]))
template f*[X:SomeTensor](x: typedesc[X]): typedesc =
tensorObj(X.kind, f(X[]))
template f*[X:SomeTensor](x: X): auto =
tensorObj(X.kind, f(x[]))

setUnopT(load1, load1)
setUnopT(`-`, neg)
Expand All @@ -185,7 +179,7 @@ setUnopT(simdSum, simdSum)

#** Binary assignment functions

template setBinAssignTT(f) {.dirty.} =
template setBinopAssignTT(f) {.dirty.} =
template f*[R:SomeTensor,X:SomeTensor2,Y:SomeTensor3](r: var R, x: X, y: Y) =
when R.kind is X.kind:
when R.kind is Y.kind:
Expand All @@ -197,74 +191,74 @@ template setBinAssignTT(f) {.dirty.} =
f(r[], x, y[])
else:
f(r[], x, y)
template setBinAssignTX(f,Y) {.dirty.} =
template setBinopAssignTX(f,Y) {.dirty.} =
template f*[R:SomeTensor,X:SomeTensor2](r: var R, x: X, y: Y) =
when R.kind is X.kind:
f(r[], x[], y)
else:
f(r[], x, y)
template setBinAssignXT(f,X) {.dirty.} =
template setBinopAssignXT(f,X) {.dirty.} =
template f*[R:SomeTensor,Y:SomeTensor3](r: var R, x: X, y: Y) =
when R.kind is Y.kind:
f(r[], x, y[])
else:
f(r[], x, y)

setBinAssignTT(add)
setBinAssignTT(sub)
setBinopAssignTT(add)
setBinopAssignTT(sub)

setBinAssignTT(mul)
setBinAssignTX(mul, SomeNumber)
setBinAssignXT(mul, SomeNumber)
setBinAssignTX(mul, AsComplex)
setBinAssignXT(mul, AsComplex)
setBinopAssignTT(mul)
setBinopAssignTX(mul, SomeNumber)
setBinopAssignXT(mul, SomeNumber)
setBinopAssignTX(mul, AsComplex)
setBinopAssignXT(mul, AsComplex)

setBinAssignTT(imadd)
setBinAssignTX(imadd, AsComplex)
setBinAssignXT(imadd, AsComplex)
setBinopAssignTT(imadd)
setBinopAssignTX(imadd, AsComplex)
setBinopAssignXT(imadd, AsComplex)

setBinAssignTT(imsub)
setBinopAssignTT(imsub)

setBinAssignTT(peqOuter)
setBinAssignTT(meqOuter)
setBinopAssignTT(peqOuter)
setBinopAssignTT(meqOuter)

#** Binary functions

template setBinopTT(f, g) {.dirty.} =
template f*(x: typedesc[SomeTensor], y: typedesc[SomeTensor2]): typedesc =
when x.kind is y.kind:
tensorObj(x.kind, f(x[],y[]))
elif x.kind.has y.kind:
f(x[], y)
else: # assume y.kind has x.kind
f(x, y[])
template f*(x: SomeTensor, y: SomeTensor2): auto =
when x.kind is y.kind:
tensorObj(x.kind, f(x[],y[]))
elif x.kind.has y.kind:
template f*[X:SomeTensor,Y:SomeTensor2](x: typedesc[X], y: typedesc[Y]): typedesc =
when X.kind is Y.kind:
tensorObj(X.kind, f(X[],Y[]))
elif X.kind.has Y.kind:
f(X[], Y)
else: # assume Y.kind has X.kind
f(X, Y[])
template f*[X:SomeTensor,Y:SomeTensor2](x: X, y: Y): auto =
when X.kind is Y.kind:
tensorObj(X.kind, f(x[],y[]))
elif X.kind.has Y.kind:
f(x[], y)
else: # assume y.kind has x.kind
else: # assume Y.kind has X.kind
f(x, y[])
#template f*(x: SomeTensor, y: SomeTensor2): auto =
# var r {.noInit.}: f(x.type, y.type)
# g(r, x, y)
# r
template setBinopTX(f,g,X) {.dirty.} =
template f*(x: typedesc[SomeTensor], y: typedesc[X]): typedesc =
tensorObj(x.kind, f(x[], y))
template f*(x: SomeTensor, y: X): auto =
tensorObj(x.kind, f(x[], y))
template setBinopTX(f,g,Y) {.dirty.} =
template f*[X:SomeTensor](x: typedesc[X], y: typedesc[Y]): typedesc =
tensorObj(X.kind, f(X[], Y))
template f*[X:SomeTensor](x: X, y: Y): auto =
tensorObj(X.kind, f(x[], y))
#template f*(x: SomeTensor, y: X): auto =
# var r {.noInit.}: f(x.type, X)
# g(r, x, y)
# r
template setBinopXT(f,g,X) {.dirty.} =
template f*(x: typedesc[X], y: typedesc[SomeTensor]): typedesc =
tensorObj(y.kind, f(x, y[]))
template f*(x: X, y: SomeTensor): auto =
tensorObj(y.kind, f(x, y[]))
#template f*(x: X, y: SomeTensor): auto =
# var r {.noInit.}: f(X, y.type)
template f*[Y:SomeTensor](x: typedesc[X], y: typedesc[Y]): typedesc =
tensorObj(Y.kind, f(X, Y[]))
template f*[Y:SomeTensor](x: X, y: Y): auto =
tensorObj(Y.kind, f(x, y[]))
#template f*[Y:SomeTensor](x: X, y: Y): auto =
# var r {.noInit.}: f(X, Y)
# g(r, x, y)
# r

Expand All @@ -281,12 +275,12 @@ setBinopTX(`*`, mul, SomeNumber)
setBinopXT(`*`, mul, SomeNumber)
setBinopTX(`*`, mul, Simd)
setBinopXT(`*`, mul, Simd)
setBinopXT(`*`, mul, AsReal)
setBinopTX(`*`, mul, AsReal)
setBinopXT(`*`, mul, AsImag)
setBinopXT(`*`, mul, AsReal)
setBinopTX(`*`, mul, AsImag)
setBinopXT(`*`, mul, AsComplex)
setBinopXT(`*`, mul, AsImag)
setBinopTX(`*`, mul, AsComplex)
setBinopXT(`*`, mul, AsComplex)

#[
#template mul*(x: SomeNumber, y: SomeTensor2): auto =
Expand All @@ -302,44 +296,54 @@ template mul*(x: AsComplex, y: SomeTensor2): auto =
asSomeTensor(mul(x, y[]))
]#

#template row*[K,T](x: SomeTensor[K,T], i: auto): auto =
# mixin row
# tensorObj(x.kind, row(x[],i))
#template setRow*[K,T1,T2](r: SomeTensor[K,T1]; x: SomeTensor2[K,T2]; i: auto) =
# setRow(r[], x[], i)
setBinopAssignTX(setRow, auto)
setBinopTX(row, setRow, auto)

#template random*(x: var SomeTensor) = gaussian(x[], r)
setUnaryAssignX(gaussian, auto)
setUnaryAssignX(uniform, auto)
setUnaryAssignX(z2, auto)
setUnaryAssignX(z4, auto)
setUnaryAssignX(u1, auto)

template projectU*(r: var SomeTensor) =
template projectU*(r: var SomeTensor) = # in place
projectU(r[])
template projectU*(r: var SomeTensor, x: SomeTensor2) =
projectU(r[], x[])
template projectUderiv*(r: var SomeTensor, u: SomeTensor2, x: SomeTensor3, chain: SomeTensor4) =
#template projectU*(r: var SomeTensor, x: SomeTensor2) =
# projectU(r[], x[])
setUnaryAssignT(projectU)

template projectUderiv*(r: var SomeTensor, u: SomeTensor2,
x: SomeTensor3, chain: SomeTensor4) =
projectUderiv(r[], u[], x[], chain[])
template projectUderiv*(r: var SomeTensor, x: SomeTensor3, chain: SomeTensor4) =
projectUderiv(r[], x[], chain[])
template projectSU*(r: var SomeTensor) =
projectSU(r[])
template projectSU*(r: var SomeTensor, x: SomeTensor) =
template projectSU*(r: var SomeTensor, x: SomeTensor2) =
projectSU(r[], x[])
template projectTAH*(r: var SomeTensor) =
projectTAH(r[])
template projectTAH*(r: var SomeTensor, x: SomeTensor2) =
projectTAH(r[], x[])
template checkU*(x: SomeTensor): auto = checkU(x[])
template checkSU*(x: SomeTensor): auto = checkSU(x[])

template norm2*(x: SomeTensor): auto = norm2(x[])
template norm2*(r: var auto, x: SomeTensor) = norm2(r, x[])
template inorm2*(r: var auto, x: SomeTensor) = inorm2(r, x[])
template dot*(x: SomeTensor, y: SomeTensor2): auto =
dot(x[], y[])
template idot*(r: var auto, x: SomeTensor2, y: SomeTensor3) =
idot(r, x[], y[])
template redot*(x: SomeTensor, y: SomeTensor2): auto =
redot(x[], y[])
template dot*(x: SomeTensor, y: SomeTensor2): auto = dot(x[], y[])
template idot*(r: var auto, x: SomeTensor2, y: SomeTensor3) = idot(r, x[], y[])
template redot*(x: SomeTensor, y: SomeTensor2): auto = redot(x[], y[])
template trace*(x: SomeTensor): auto = trace(x[])
template exp*(x: SomeTensor): auto =
tensorObj(x.kind, exp(x[]))
template expDeriv*(x: SomeTensor, c: SomeTensor2): auto =
tensorObj(x.kind, expDeriv(x[], c[]))
template ln*(x: SomeTensor): auto =
tensorObj(x.kind, ln(x[]))

template exp*[X:SomeTensor](x: X): auto =
tensorObj(X.kind, exp(x[]))
template expDeriv*[X:SomeTensor,C:SomeTensor2](x: X, c: C): auto =
tensorObj(X.kind, expDeriv(x[], c[]))
template ln*[X:SomeTensor](x: X): auto =
tensorObj(X.kind, ln(x[]))

0 comments on commit 941a21c

Please sign in to comment.