Skip to content

Commit

Permalink
Fix for Delta(int, Index)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 22, 2025
1 parent 27cf83d commit fdff8da
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 12 deletions.
12 changes: 1 addition & 11 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,19 +982,9 @@ def __new__(cls, i, j, dtype=None):
return one

# Fixed indices
if isinstance(i, int) and isinstance(j, int):
if isinstance(i, Integral) and isinstance(j, Integral):
return one if i == j else Zero()

if isinstance(i, int):
expr = numpy.full((j.extent), Zero(), dtype=object)
expr[i] = one
return Indexed(ListTensor(expr), (j,))

if isinstance(j, int):
expr = numpy.full((i.extent), Zero(), dtype=object)
expr[j] = one
return Indexed(ListTensor(expr), (i,))

self = super(Delta, cls).__new__(cls)
self.i = i
self.j = j
Expand Down
2 changes: 1 addition & 1 deletion gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ def _replace_delta_delta(node, self):
return Indexed(Identity(size), (i, j))
else:
def expression(index):
if isinstance(index, int):
if isinstance(index, Integral):
return Literal(index)
elif isinstance(index, VariableIndex):
return index.expression
Expand Down

0 comments on commit fdff8da

Please sign in to comment.