@@ -17,7 +17,6 @@ import (
17
17
type Config struct {
18
18
Package string `yaml:"package"`
19
19
Import []string `yaml:"import"`
20
- JSONPackage string `yaml:"json-package"`
21
20
ScalarMapper map [string ]string `yaml:"scalarMapper"`
22
21
}
23
22
@@ -51,11 +50,12 @@ func unwrapType(typ types.Type) types.Type {
51
50
type generator struct {
52
51
schema * types.Schema
53
52
schemaTypes []types.NamedType
53
+ extends []GeneratorExtend
54
54
config Config
55
- w util.Output
55
+ w * util.Output
56
56
}
57
57
58
- func newGenerator (schema * types.Schema , config Config , out io.Writer ) generator {
58
+ func newGenerator (schema * types.Schema , config Config , out io.Writer , extends [] GeneratorExtend ) generator {
59
59
schemaTypes := make ([]types.NamedType , 0 , len (schema .Types ))
60
60
for _ , typ := range schema .Types {
61
61
schemaTypes = append (schemaTypes , typ )
@@ -67,17 +67,38 @@ func newGenerator(schema *types.Schema, config Config, out io.Writer) generator
67
67
return generator {
68
68
schema : schema ,
69
69
schemaTypes : schemaTypes ,
70
+ extends : extends ,
70
71
config : config ,
71
- w : util.Output {Writer : out },
72
+ w : & util.Output {Writer : out },
72
73
}
73
74
}
74
75
76
+ type GeneratorExtend interface {
77
+ Imports () []string
78
+ GenScalar (out * util.Output , scalarType * types.ScalarTypeDefinition , scalarGoType string )
79
+ GenEnum (out * util.Output , enumType * types.EnumTypeDefinition )
80
+ GenObject (out * util.Output , objectType * types.ObjectTypeDefinition )
81
+ GenUnion (out * util.Output , unionType * types.Union )
82
+ }
83
+
75
84
func (g generator ) genHeader () {
76
85
// package
77
86
g .w .Out ("package %s\n \n " , g .config .Package )
78
87
// import
79
88
g .w .Out ("import (\n " )
80
- for _ , imp := range append (g .config .Import , g .config .JSONPackage , "fmt" , "reflect" , "strconv" ) {
89
+ imports := map [string ]bool {
90
+ "fmt" : true ,
91
+ "strconv" : true ,
92
+ }
93
+ for _ , ex := range g .extends {
94
+ for _ , imp := range ex .Imports () {
95
+ imports [imp ] = true
96
+ }
97
+ }
98
+ for _ , imp := range g .config .Import {
99
+ imports [imp ] = true
100
+ }
101
+ for imp := range imports {
81
102
g .w .Out ("\t %q\n " , imp )
82
103
}
83
104
g .w .Out (")\n \n " )
@@ -94,115 +115,44 @@ func (g generator) genScalars() {
94
115
if ! has {
95
116
log .Fatalf ("miss mapping for scalar type %q" , scalarType .TypeName ())
96
117
}
97
-
98
118
switch scalarGoType {
99
119
case "uint8" , "uint16" , "uint32" , "uint64" :
100
120
g .w .Out ("type %s %s\n " , scalarType .TypeName (), scalarGoType )
101
121
g .w .Outf (`
102
- func (s *#{name}) UnmarshalJSON(raw []byte) error {
103
- if i, err := UnmarshalJSONUInt(raw); err != nil {
104
- return err
105
- } else {
106
- *s = #{name}(i)
107
- return nil
108
- }
109
- }
110
-
111
- func (s #{name}) MarshalJSON() ([]byte, error) {
112
- return json.Marshal(s.String())
122
+ func (s #{name}) String() string {
123
+ return strconv.FormatUint(uint64(s), 10)
113
124
}
114
-
115
- func (s *#{name}) String() string {
116
- return strconv.FormatUint(uint64(*s), 10)
117
- }
118
-
119
125
` , "name" , scalarType .TypeName ())
120
126
case "int8" , "int16" , "int32" , "int64" :
121
127
g .w .Out ("type %s %s\n " , scalarType .TypeName (), scalarGoType )
122
128
g .w .Outf (`
123
- func (s *#{name}) UnmarshalJSON(raw []byte) error {
124
- if i, err := UnmarshalJSONInt(raw); err != nil {
125
- return err
126
- } else {
127
- *s = #{name}(i)
128
- return nil
129
- }
129
+ func (s #{name}) String() string {
130
+ return strconv.FormatInt(int64(s), 10)
130
131
}
131
-
132
- func (s #{name}) MarshalJSON() ([]byte, error) {
133
- return json.Marshal(s.String())
134
- }
135
-
136
- func (s *#{name}) String() string {
137
- return strconv.FormatInt(int64(*s), 10)
138
- }
139
-
140
132
` , "name" , scalarType .TypeName ())
141
133
case "float32" , "float64" :
142
134
g .w .Out ("type %s %s\n " , scalarType .TypeName (), scalarGoType )
143
135
g .w .Outf (`
144
- func (s *#{name}) UnmarshalJSON(raw []byte) error {
145
- if f, err := UnmarshalJSONFloat(raw); err != nil {
146
- return err
147
- } else {
148
- *s = #{name}(f)
149
- return nil
150
- }
151
- }
152
-
153
- func (s #{name}) MarshalJSON() ([]byte, error) {
154
- return json.Marshal(s.String())
155
- }
156
-
157
- func (s *#{name}) String() string {
158
- return strconv.FormatFloat(float64(*s), 'f', 20, 64)
136
+ func (s #{name}) String() string {
137
+ return strconv.FormatFloat(float64(s), 'f', 20, 64)
159
138
}
160
-
161
139
` , "name" , scalarType .TypeName ())
162
140
case "string" :
163
141
g .w .Out ("type %s %s\n " , scalarType .TypeName (), scalarGoType )
164
142
case "bool" :
165
143
g .w .Out ("type %s %s\n " , scalarType .TypeName (), scalarGoType )
166
144
g .w .Outf (`
167
- func (s * #{name}) String() string {
168
- return strconv.FormatBool(bool(* s))
145
+ func (s #{name}) String() string {
146
+ return strconv.FormatBool(bool(s))
169
147
}
170
-
171
148
` , "name" , scalarType .TypeName ())
172
149
default :
173
150
g .w .Out ("type %s struct { %s }\n " , scalarType .TypeName (), scalarGoType )
174
- g .w .Outf (`
175
- func (s #{name}) MarshalStructpb() *structpb.Value {
176
- return structpb.NewStringValue(s.String())
177
- }
178
-
179
- ` , "name" , scalarType .TypeName ())
151
+ }
152
+ for _ , ex := range g .extends {
153
+ ex .GenScalar (g .w , scalarType , scalarGoType )
180
154
}
181
155
}
182
- g .w .Out (`
183
-
184
- func UnmarshalJSONUInt(raw []byte) (uint64, error) {
185
- if len(raw) >= 2 && raw[0] == '"' && raw[len(raw)-1] == '"' {
186
- raw = raw[1:len(raw)-1]
187
- }
188
- return strconv.ParseUint(string(raw), 10, 64)
189
- }
190
-
191
- func UnmarshalJSONInt(raw []byte) (int64, error) {
192
- if len(raw) >= 2 && raw[0] == '"' && raw[len(raw)-1] == '"' {
193
- raw = raw[1:len(raw)-1]
194
- }
195
- return strconv.ParseInt(string(raw), 10, 64)
196
- }
197
-
198
- func UnmarshalJSONFloat(raw []byte) (float64, error) {
199
- if len(raw) >= 2 && raw[0] == '"' && raw[len(raw)-1] == '"' {
200
- raw = raw[1:len(raw)-1]
201
- }
202
- return strconv.ParseFloat(string(raw), 64)
203
- }
204
-
205
- ` )
206
156
}
207
157
208
158
func (g generator ) genEnums () {
@@ -218,24 +168,11 @@ func (g generator) genEnums() {
218
168
for _ , val := range enumType .EnumValuesDefinition {
219
169
g .w .Out (" %q,\n " , val .EnumValue )
220
170
}
221
- g .w .Outf (`}
222
-
223
- func (e *#{name}) UnmarshalJSON(raw []byte) error {
224
- var val string
225
- if err := json.Unmarshal(raw, &val); err != nil {
226
- return err
227
- }
228
- for _, v := range #{name}Values {
229
- if v == val {
230
- *e = #{name}(val)
231
- return nil
171
+ g .w .Out ("}\n \n " )
172
+ for _ , ex := range g .extends {
173
+ ex .GenEnum (g .w , enumType )
232
174
}
233
175
}
234
- return fmt.Errorf("invalid value %%q for enum type #{name}", val)
235
- }
236
-
237
- ` , "name" , enumType .TypeName ())
238
- }
239
176
}
240
177
241
178
func (g generator ) genObjects () {
@@ -254,7 +191,10 @@ func (g generator) genObjects() {
254
191
g .w .Out ("\t // SCHEMA: %s %s\n " , field .Name , field .Type .String ())
255
192
g .w .Out ("\t %s %s `%s`\n " , goFieldName , convertToGoType (field .Type , false ), tags )
256
193
}
257
- g .w .Out ("}\n \n " )
194
+ g .w .Out ("}\n " )
195
+ for _ , ex := range g .extends {
196
+ ex .GenObject (g .w , objectType )
197
+ }
258
198
}
259
199
}
260
200
@@ -268,88 +208,10 @@ func (g generator) genUnions() {
268
208
g .w .Out ("\t *%s\n " , mem .TypeName ())
269
209
}
270
210
g .w .Out ("}\n \n " )
271
- g .w .Outf (`
272
- func (u *#{name}) UnmarshalJSON(raw []byte) error {
273
- return UnmarshalJSONUnion(raw, u)
274
- }
275
-
276
- func (u #{name}) MarshalJSON() ([]byte, error) {
277
- return MarshalJSONUnion(u)
278
- }
279
-
280
- ` , "name" , unionType .TypeName ())
281
- }
282
- g .w .Outf (`
283
-
284
- func UnmarshalJSONUnion(raw []byte, unionObj any) error {
285
- var union struct {
286
- TypeName string ` + "`" + `json:"__typename"` + "`" + `
287
- }
288
- if err := json.Unmarshal(raw, &union); err != nil {
289
- return err
290
- }
291
- if union.TypeName == "" {
292
- return nil
293
- }
294
- pv := reflect.ValueOf(unionObj)
295
- if pv.Kind() != reflect.Pointer || pv.IsNil() {
296
- return &json.InvalidUnmarshalError{Type: reflect.TypeOf(unionObj)}
297
- }
298
- rv := pv.Elem()
299
- rt := rv.Type()
300
- if _, has := rt.FieldByName("#{typeFieldName}"); !has {
301
- return fmt.Errorf("%%s is not an union type because miss field #{typeFieldName}", rt.Name())
302
- }
303
- rv.FieldByName("#{typeFieldName}").SetString(union.TypeName)
304
- for i := 0; i < rt.NumField(); i++ {
305
- if rt.Field(i).Name == union.TypeName {
306
- if rt.Field(i).Type.Kind() != reflect.Pointer {
307
- return fmt.Errorf("member %%s of union type %%T should be an pointer", union.TypeName, unionObj)
308
- }
309
- fv := reflect.New(rt.Field(i).Type.Elem())
310
- if err := json.Unmarshal(raw, fv.Interface()); err != nil {
311
- return err
312
- }
313
- rv.Field(i).Set(fv)
314
- return nil
211
+ for _ , ex := range g .extends {
212
+ ex .GenUnion (g .w , unionType )
315
213
}
316
214
}
317
- return fmt.Errorf("union type %%T do not have member %%q", unionObj, union.TypeName)
318
- }
319
-
320
- func MarshalJSONUnion(unionObj any) ([]byte, error) {
321
- val := reflect.ValueOf(unionObj)
322
- vt := val.Type()
323
- if _, has := vt.FieldByName("#{typeFieldName}"); !has {
324
- return nil, fmt.Errorf("%%s is not an union type because miss field #{typeFieldName}", vt.Name())
325
- }
326
- typeName := val.FieldByName("#{typeFieldName}").Interface().(string)
327
- if typeName == "" {
328
- return json.Marshal(nil)
329
- }
330
- if _, has := vt.FieldByName(typeName); !has {
331
- return json.Marshal(map[string]string{"__typename": typeName})
332
- }
333
- subVal := val.FieldByName(typeName)
334
- if subVal.IsNil() {
335
- return nil, fmt.Errorf("%%s can not be nil", typeName)
336
- }
337
- subVal = subVal.Elem()
338
- subTyp := subVal.Type()
339
- fields := make([]reflect.StructField, subVal.NumField()+1)
340
- fields[0], _ = vt.FieldByName("#{typeFieldName}")
341
- for i := 0; i < subTyp.NumField(); i++ {
342
- fields[i+1] = subTyp.Field(i)
343
- }
344
- merged := reflect.New(reflect.StructOf(fields)).Elem()
345
- merged.Field(0).SetString(typeName)
346
- for i := 0; i < subVal.NumField(); i++ {
347
- merged.Field(i + 1).Set(subVal.Field(i))
348
- }
349
- return json.Marshal(merged.Interface())
350
- }
351
-
352
- ` , "typeFieldName" , util .UnionTypeFieldName )
353
215
}
354
216
355
217
func (g generator ) genInputObjects () {
@@ -445,5 +307,9 @@ func main() {
445
307
defer out .Close ()
446
308
447
309
// gen
448
- newGenerator (schema .ASTSchema (), conf , out ).Gen ()
310
+ extends := []GeneratorExtend {
311
+ generatorExtendJSON {},
312
+ generatorExtendStructpb {},
313
+ }
314
+ newGenerator (schema .ASTSchema (), conf , out , extends ).Gen ()
449
315
}
0 commit comments