diff --git a/oracle.go b/oracle.go index 2945442..3d3da3c 100644 --- a/oracle.go +++ b/oracle.go @@ -82,6 +82,44 @@ func GetStringExpr(value string, quotes ...bool) clause.Expr { return gorm.Expr(value) } +// AddSessionParams setting database connection session parameters +func AddSessionParams(db *sql.DB, params map[string]string) (keys []string, err error) { + if db == nil { + return + } + if _, ok := db.Driver().(*go_ora.OracleDriver); !ok { + return + } + + for key, value := range params { + if key == "" || value == "" { + continue + } + if err = go_ora.AddSessionParam(db, key, value); err != nil { + return + } + keys = append(keys, key) + } + return +} + +// DelSessionParams remove session parameters +func DelSessionParams(db *sql.DB, keys []string) { + if db == nil { + return + } + if _, ok := db.Driver().(*go_ora.OracleDriver); !ok { + return + } + + for _, key := range keys { + if key == "" { + continue + } + go_ora.DelSessionParam(db, key) + } +} + func convertCustomType(val interface{}) interface{} { rv := reflect.ValueOf(val) ri := rv.Interface() diff --git a/oracle_test.go b/oracle_test.go index 179e273..46e615b 100644 --- a/oracle_test.go +++ b/oracle_test.go @@ -1,6 +1,7 @@ package oracle import ( + "database/sql" "log" "os" "reflect" @@ -111,3 +112,52 @@ func openTestConnection(ignoreCase, namingCase bool) (db *gorm.DB, err error) { } return } + +func TestAddSessionParams(t *testing.T) { + db, err := openTestConnection(true, false) + if err != nil { + t.Fatal(err) + } + var sqlDB *sql.DB + if sqlDB, err = db.DB(); err != nil { + t.Fatal(err) + } + type args struct { + params map[string]string + } + tests := []struct { + name string + args args + }{ + {name: "TimeParams", args: args{params: map[string]string{ + "TIME_ZONE": "+08:00", // alter session set TIME_ZONE = '+08:00'; + "NLS_DATE_FORMAT": "YYYY-MM-DD", // alter session set NLS_DATE_FORMAT = 'YYYY-MM-DD'; + "NLS_TIME_FORMAT": "HH24:MI:SSXFF", // alter session set NLS_TIME_FORMAT = 'HH24:MI:SS.FF3'; + "NLS_TIMESTAMP_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF", // alter session set NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3'; + "NLS_TIME_TZ_FORMAT": "HH24:MI:SS.FF TZR", // alter session set NLS_TIME_TZ_FORMAT = 'HH24:MI:SS.FF3 TZR'; + "NLS_TIMESTAMP_TZ_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF TZR", // alter session set NLS_TIMESTAMP_TZ_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3 TZR'; + }}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + //queryTime := `SELECT SYSDATE FROM DUAL` + queryTime := `SELECT CAST(SYSDATE AS VARCHAR(30)) AS D FROM DUAL` + var timeStr string + if err = db.Raw(queryTime).Row().Scan(&timeStr); err != nil { + t.Fatal(err) + } + t.Logf("SYSDATE 1: %s", timeStr) + + var keys []string + if keys, err = AddSessionParams(sqlDB, tt.args.params); err != nil { + t.Fatalf("AddSessionParams() error = %v", err) + } + if err = db.Raw(queryTime).Row().Scan(&timeStr); err != nil { + t.Fatal(err) + } + defer DelSessionParams(sqlDB, keys) + t.Logf("SYSDATE 2: %s", timeStr) + t.Logf("keys: %#v", keys) + }) + } +}