Skip to content

Commit

Permalink
WIP check arg type against defs at parsing time
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Mar 15, 2024
1 parent 46b5044 commit f9c883a
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1939,7 +1939,7 @@ private static boolean isAssociativitySafeExpression(Expression expr) {
* @return {@code true} if a conversion from {@code original} to {@code target} is a widening conversion; otherwise,
* {@code false}.
*/
static boolean isWideningPrimitiveConversion(Class<?> original, Class<?> target) {
public static boolean isWideningPrimitiveConversion(Class<?> original, Class<?> target) {
if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive()
|| original.equals(void.class) || target.equals(void.class)) {
throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!");
Expand Down Expand Up @@ -2498,13 +2498,36 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {

// Attempt python function call vectorization.
if (scopeType != null && PyCallableWrapper.class.isAssignableFrom(scopeType)) {
verifyPyCallableArguments(n, argTypes);
tryVectorizePythonCallable(n, scopeType, convertedArgExpressions, argTypes);
}

return calculateMethodReturnTypeUsingGenerics(scopeType, n.getScope().orElse(null), method, expressionTypes,
typeArguments);
}

/**
 * Checks the argument types of a call made on a {@link PyCallableWrapper} against the signature of the Python
 * callable it refers to.
 *
 * <p>
 * Nothing is checked when the invoked method is {@code getAttribute()} (UDF type checks are not currently supported
 * for those calls), when the call carries no {@code PY_CALLABLE_DETAILS} data, or when the query scope variable named
 * by those details is not a {@link PyCallableWrapper}.
 *
 * @param n the method call expression being visited
 * @param argTypes the argument types of the call, handed to {@link PyCallableWrapper#verifyArguments(Class[])}
 */
private void verifyPyCallableArguments(@NotNull MethodCallExpr n, @NotNull Class<?>[] argTypes) {
    // Only PyCallableWrapper.getAttribute()/PyCallableWrapper.call() may be invoked from the query language, and
    // getAttribute() calls are not type checked; calls without recorded callable details cannot be checked either.
    if (GET_ATTRIBUTE_METHOD_NAME.equals(n.getNameAsString())
            || !n.containsData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)) {
        return;
    }

    final PyCallableDetails details = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS);
    final Object scopeValue = queryScopeVariables.get(details.pythonMethodName);
    if (scopeValue instanceof PyCallableWrapper) {
        final PyCallableWrapper wrapper = (PyCallableWrapper) scopeValue;
        // The signature must be parsed before the arguments can be compared against it.
        wrapper.parseSignature();
        wrapper.verifyArguments(argTypes);
    }
}

private Optional<CastExpr> makeCastExpressionForPyCallable(Class<?> retType, MethodCallExpr callMethodCall) {
if (retType.isPrimitive()) {
return Optional.of(new CastExpr(
Expand Down Expand Up @@ -2726,11 +2749,10 @@ private void checkVectorizability(@NotNull final MethodCallExpr n,
}
}

List<Class<?>> paramTypes = pyCallableWrapper.getParamTypes();
if (paramTypes.size() != expressions.length) {
if (pyCallableWrapper.getNumParameters() != expressions.length) {
// note vectorization doesn't handle Python variadic arguments
throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " "
+ paramTypes.size() + " vs. " + expressions.length);
+ pyCallableWrapper.getNumParameters() + " vs. " + expressions.length);
}
}

Expand All @@ -2739,10 +2761,9 @@ private void prepareVectorizationArgs(
Expression[] expressions,
Class<?>[] argTypes,
PyCallableWrapper pyCallableWrapper) {
List<Class<?>> paramTypes = pyCallableWrapper.getParamTypes();
if (paramTypes.size() != expressions.length) {
if (pyCallableWrapper.getNumParameters() != expressions.length) {
throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " "
+ paramTypes.size() + " vs. " + expressions.length);
+ pyCallableWrapper.getNumParameters() + " vs. " + expressions.length);
}

pyCallableWrapper.initializeChunkArguments();
Expand All @@ -2764,10 +2785,11 @@ private void prepareVectorizationArgs(
throw new IllegalStateException("Vectorizability check failed: " + n);
}

if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) {
throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " "
+ argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName());
}
// TODO: related to core#709; this check should be covered by PyCallableWrapper.verifyArguments, but that still needs to be verified.
// if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) {
// throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " "
// + argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName());
// }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.jpy.PyObject;

import java.util.List;
import java.util.Set;

/**
* Created by rbasralian on 8/12/23
Expand All @@ -19,7 +20,9 @@ public interface PyCallableWrapper {

Object call(Object... args);

List<Class<?>> getParamTypes();
List<Parameter> getParameters();

int getNumParameters();

boolean isVectorized();

Expand All @@ -33,6 +36,8 @@ public interface PyCallableWrapper {

Class<?> getReturnType();

void verifyArguments(Class<?>[] argTypes);

abstract class ChunkArgument {
private final Class<?> type;

Expand Down Expand Up @@ -88,4 +93,41 @@ public Object getValue() {
}

boolean isVectorizableReturnType();

/**
 * Holder for a callable's signature: its parameters together with its return type.
 */
class Signature {
    private final Class<?> returnType;
    private final List<Parameter> parameters;

    /**
     * @param parameters the parameters of the callable; the list is stored as given, not copied
     * @param returnType the return type of the callable
     */
    public Signature(List<Parameter> parameters, Class<?> returnType) {
        this.returnType = returnType;
        this.parameters = parameters;
    }

    /**
     * @return the return type given at construction
     */
    public Class<?> getReturnType() {
        return this.returnType;
    }

    /**
     * @return the parameter list given at construction (the same instance, not a copy)
     */
    public List<Parameter> getParameters() {
        return this.parameters;
    }
}
/**
 * Describes one parameter of a callable: its name and the set of types it may accept.
 */
class Parameter {
    private final Set<Class<?>> possibleTypes;
    private final String name;

    /**
     * @param name the parameter name
     * @param possibleTypes the types the parameter may accept; the set is stored as given, not copied
     */
    public Parameter(String name, Set<Class<?>> possibleTypes) {
        this.possibleTypes = possibleTypes;
        this.name = name;
    }

    /**
     * @return the parameter name given at construction
     */
    public String getName() {
        return this.name;
    }

    /**
     * @return the set of possible types given at construction (the same instance, not a copy)
     */
    public Set<Class<?>> getPossibleTypes() {
        return this.possibleTypes;
    }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,9 @@
import org.jpy.PyObject;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;

import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isWideningPrimitiveConversion;

/**
* When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs
Expand All @@ -30,6 +27,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
private static final PyModule dh_udf_module = PyModule.importModule("deephaven._udf");

private static final Map<Character, Class<?>> numpyType2JavaClass = new HashMap<>();
private static final Map<Character, Class<?>> numpyType2JavaArrayClass = new HashMap<>();

static {
numpyType2JavaClass.put('b', byte.class);
Expand All @@ -43,8 +41,22 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
numpyType2JavaClass.put('U', String.class);
numpyType2JavaClass.put('M', Instant.class);
numpyType2JavaClass.put('O', Object.class);

numpyType2JavaArrayClass.put('b', byte[].class);
numpyType2JavaArrayClass.put('h', short[].class);
numpyType2JavaArrayClass.put('H', char[].class);
numpyType2JavaArrayClass.put('i', int[].class);
numpyType2JavaArrayClass.put('l', long[].class);
numpyType2JavaArrayClass.put('f', float[].class);
numpyType2JavaArrayClass.put('d', double[].class);
numpyType2JavaArrayClass.put('?', boolean[].class);
numpyType2JavaArrayClass.put('U', String[].class);
numpyType2JavaArrayClass.put('M', Instant[].class);
numpyType2JavaArrayClass.put('O', Object[].class);
}



/**
* Ensure that the class initializer runs.
*/
Expand Down Expand Up @@ -75,8 +87,8 @@ public boolean isVectorizableReturnType() {

private final PyObject pyCallable;

private String signature = null;
private List<Class<?>> paramTypes;
private String signatureString = null;
private List<Parameter> parameters = new ArrayList<>();
private Class<?> returnType;
private boolean vectorizable = false;
private boolean vectorized = false;
Expand Down Expand Up @@ -168,40 +180,61 @@ private void prepareSignature() {
vectorized = false;
}
pyUdfDecoratedCallable = dh_udf_module.call("_py_udf", unwrapped);
signature = pyUdfDecoratedCallable.getAttribute("signature").toString();
signatureString = pyUdfDecoratedCallable.getAttribute("signature").toString();
}


@Override
public void parseSignature() {
if (signature != null) {
if (signatureString != null) {
return;
}

prepareSignature();

// the 'types' field of a vectorized function follows the pattern of '[ilhfdb?O]*->[ilhfdb?O]',
// eg. [ll->d] defines two int64 (long) arguments and a double return type.
if (signature == null || signature.isEmpty()) {
if (signatureString == null || signatureString.isEmpty()) {
throw new IllegalStateException("Signature should always be available.");
}

List<Class<?>> paramTypes = new ArrayList<>();
for (char numpyTypeChar : signature.toCharArray()) {
if (numpyTypeChar != '-') {
Class<?> paramType = numpyType2JavaClass.get(numpyTypeChar);
if (paramType == null) {
throw new IllegalStateException(
"Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object type: "
+ numpyTypeChar + " of " + signature);
// List<Class<?>> paramTypes = new ArrayList<>();
// for (char numpyTypeChar : signatureString.toCharArray()) {
// if (numpyTypeChar != '-') {
// Class<?> paramType = numpyType2JavaClass.get(numpyTypeChar);
// if (paramType == null) {
// throw new IllegalStateException(
// "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object type: "
// + numpyTypeChar + " of " + signatureString);
// }
// paramTypes.add(paramType);
// } else {
// break;
// }
// }
// this.paramTypes = paramTypes;
String pyEncodedParamsStr = signatureString.split("->")[0];
if (!pyEncodedParamsStr.isEmpty()){
String[] pyEncodedParams = pyEncodedParamsStr.split(",");
for (int i = 0; i < pyEncodedParams.length; i++) {
String[] paramDetail = pyEncodedParams[i].split(":");
String paramName = paramDetail[0];
String paramTypeCodes = paramDetail[1];
Set<Class<?>> possibleTypes = new HashSet<>();
for (int ti = 0; ti < paramTypeCodes.length(); ti++) {
char typeCode = paramTypeCodes.charAt(ti);
if (typeCode == '[') {
// skip the array type code
ti++;
possibleTypes.add(numpyType2JavaArrayClass.get(paramTypeCodes.charAt(ti)));
} else {
possibleTypes.add(numpyType2JavaClass.get(typeCode));
}
}
paramTypes.add(paramType);
} else {
break;
parameters.add(new Parameter(paramName, possibleTypes));
}
}

this.paramTypes = paramTypes;

returnType = pyUdfDecoratedCallable.getAttribute("return_type", null);
if (returnType == null) {
throw new IllegalStateException(
Expand All @@ -213,6 +246,35 @@ public void parseSignature() {
}
}

/**
 * Tells whether a value of {@code type} is acceptable where any one of the candidate {@code types} is expected.
 *
 * <p>
 * A candidate accepts {@code type} when it is assignable from it, or when both are primitive and converting from
 * {@code type} to the candidate is a widening primitive conversion.
 *
 * @param types the candidate types
 * @param type the type to test
 * @return {@code true} if at least one candidate accepts {@code type}; {@code false} otherwise, including when
 *         {@code types} is empty
 */
private boolean isSafelyCastable(Set<Class<?>> types, Class<?> type) {
    return types.stream().anyMatch(candidate -> candidate.isAssignableFrom(type)
            || (candidate.isPrimitive() && type.isPrimitive() && isWideningPrimitiveConversion(type, candidate)));
}


/**
 * Checks the given argument types against the parameters currently held by this wrapper.
 *
 * <p>
 * An argument is accepted when its parameter's possible types contain the argument type itself, contain
 * {@code Object.class}, or contain a type the argument can be safely cast to (see {@code isSafelyCastable}).
 *
 * @param argTypes the types of the arguments, in positional order
 * @throws IllegalArgumentException if the number of arguments differs from the number of parameters, or if an
 *         argument type is not acceptable for the parameter at the same position
 */
@Override
public void verifyArguments(Class<?>[] argTypes) {
    final int expectedCount = parameters.size();
    if (argTypes.length != expectedCount) {
        // The callable's name is only used in error messages, so it is fetched from Python on the failure path only.
        final String callableName = pyCallable.getAttribute("__name__").toString();
        throw new IllegalArgumentException(
                callableName + ": " + "Expected " + expectedCount + " arguments, got " + argTypes.length);
    }

    for (int i = 0; i < argTypes.length; i++) {
        final Parameter parameter = parameters.get(i);
        final Set<Class<?>> types = parameter.getPossibleTypes();
        final Class<?> argType = argTypes[i];

        final boolean acceptable = types.contains(argType)
                || types.contains(Object.class)
                || isSafelyCastable(types, argType);
        if (!acceptable) {
            final String callableName = pyCallable.getAttribute("__name__").toString();
            throw new IllegalArgumentException(
                    callableName + ": " + "Expected argument (" + parameter.getName() + ") to be one of " + types
                            + ", got " + argType);
        }
    }
}

// In vectorized mode, we want to call the vectorized function directly.
public PyObject vectorizedCallable() {
if (numbaVectorized || vectorized) {
Expand All @@ -230,8 +292,13 @@ public Object call(Object... args) {
}

@Override
public List<Class<?>> getParamTypes() {
return paramTypes;
public List<Parameter> getParameters() {
return parameters;
}

/**
 * Returns the number of entries currently in this wrapper's {@code parameters} list.
 *
 * @return the size of the {@code parameters} list
 */
@Override
public int getNumParameters() {
return parameters.size();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.deephaven.engine.util.PyCallableWrapper;
import org.jpy.PyObject;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

Expand Down Expand Up @@ -45,6 +46,15 @@ public Object call(Object... args) {
}

/**
 * Returns an empty parameter list; this implementation reports no parameters.
 *
 * @return a newly created, empty list on every call
 */
@Override
public List<Parameter> getParameters() {
return new ArrayList<>();
}

/**
 * Returns the parameter count, which is always zero for this implementation.
 *
 * @return {@code 0}
 */
@Override
public int getNumParameters() {
return 0;
}

/**
 * Returns the {@code parameterTypes} field of this wrapper.
 *
 * @return the value of {@code parameterTypes} (the same instance, not a copy)
 */
public List<Class<?>> getParamTypes() {
return parameterTypes;
}
Expand Down Expand Up @@ -75,6 +85,11 @@ public Class<?> getReturnType() {
return Object.class;
}

/**
 * No-op: this implementation does not check argument types.
 *
 * @param argTypes the argument types of the call; not used here
 */
@Override
public void verifyArguments(Class<?>[] argTypes) {

}

@Override
public boolean isVectorizableReturnType() {
return false;
Expand Down
Loading

0 comments on commit f9c883a

Please sign in to comment.