1
2
3
4
5
6
7 package sql
8
9 import (
10 "bytes"
11 "database/sql/driver"
12 "errors"
13 "fmt"
14 "reflect"
15 "strconv"
16 "time"
17 "unicode"
18 "unicode/utf8"
19 )
20
// errNilPtr is returned by the conversion routines when the destination
// pointer they are asked to store into is nil.
var errNilPtr = errors.New("destination pointer is nil")
22
// describeNamedValue renders a short label for nv for use in argument
// conversion error messages: `with name "x"` for a named value, or "$N"
// (its 1-based ordinal) for a positional one.
func describeNamedValue(nv *driver.NamedValue) string {
	if nv.Name != "" {
		return fmt.Sprintf("with name %q", nv.Name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
29
// validateNamedValueName checks the name of a named query argument.
// An empty name (a positional argument) is always accepted; any other name
// must start with a Unicode letter, otherwise an error is returned.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
40
41
42
43
// ccChecker adapts a statement's driver.ColumnConverter to the
// CheckNamedValue method shape used by driverArgsConnLocked, so the column
// converter can take part in the same checker chain as a
// driver.NamedValueChecker.
type ccChecker struct {
	// cci is the statement's column converter; it stays nil when the
	// statement does not implement driver.ColumnConverter, in which case
	// CheckNamedValue returns driver.ErrSkip.
	cci driver.ColumnConverter
	// want is the statement's NumInput() as recorded by driverArgsConnLocked.
	// Arguments whose zero-based index is >= want are passed through
	// without conversion.
	want int
}
48
49 func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
50 if c.cci == nil {
51 return driver.ErrSkip
52 }
53
54
55
56 index := nv.Ordinal - 1
57 if c.want <= index {
58 return nil
59 }
60
61
62
63
64 if vr, ok := nv.Value.(driver.Valuer); ok {
65 sv, err := callValuerValue(vr)
66 if err != nil {
67 return err
68 }
69 if !driver.IsValue(sv) {
70 return fmt.Errorf("non-subset type %T returned from Value", sv)
71 }
72 nv.Value = sv
73 }
74
75
76
77
78
79
80
81
82 var err error
83 arg := nv.Value
84 nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
85 if err != nil {
86 return err
87 }
88 if !driver.IsValue(nv.Value) {
89 return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
90 }
91 return nil
92 }
93
94
95
96
// defaultCheckNamedValue is the last-resort argument checker: it converts
// nv.Value in place with driver.DefaultParameterConverter and returns the
// converter's error. nv.Value is replaced with the converter's result even
// when that error is non-nil.
func defaultCheckNamedValue(nv *driver.NamedValue) error {
	converted, err := driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
101
102
103
104
105
106
107
108 func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
109 nvargs := make([]driver.NamedValue, len(args))
110
111
112
113
114 want := -1
115
116 var si driver.Stmt
117 var cc ccChecker
118 if ds != nil {
119 si = ds.si
120 want = ds.si.NumInput()
121 cc.want = want
122 }
123
124
125
126
127
128 nvc, ok := si.(driver.NamedValueChecker)
129 if !ok {
130 nvc, ok = ci.(driver.NamedValueChecker)
131 }
132 cci, ok := si.(driver.ColumnConverter)
133 if ok {
134 cc.cci = cci
135 }
136
137
138
139
140
141
142 var err error
143 var n int
144 for _, arg := range args {
145 nv := &nvargs[n]
146 if np, ok := arg.(NamedArg); ok {
147 if err = validateNamedValueName(np.Name); err != nil {
148 return nil, err
149 }
150 arg = np.Value
151 nv.Name = np.Name
152 }
153 nv.Ordinal = n + 1
154 nv.Value = arg
155
156
157
158
159
160
161
162
163
164
165
166
167 checker := defaultCheckNamedValue
168 nextCC := false
169 switch {
170 case nvc != nil:
171 nextCC = cci != nil
172 checker = nvc.CheckNamedValue
173 case cci != nil:
174 checker = cc.CheckNamedValue
175 }
176
177 nextCheck:
178 err = checker(nv)
179 switch err {
180 case nil:
181 n++
182 continue
183 case driver.ErrRemoveArgument:
184 nvargs = nvargs[:len(nvargs)-1]
185 continue
186 case driver.ErrSkip:
187 if nextCC {
188 nextCC = false
189 checker = cc.CheckNamedValue
190 } else {
191 checker = defaultCheckNamedValue
192 }
193 goto nextCheck
194 default:
195 return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
196 }
197 }
198
199
200
201 if want != -1 && len(nvargs) != want {
202 return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
203 }
204
205 return nvargs, nil
206 }
207
208
209
210 func convertAssign(dest, src any) error {
211 return convertAssignRows(dest, src, nil)
212 }
213
214
215
216
217
218
219 func convertAssignRows(dest, src any, rows *Rows) error {
220
221 switch s := src.(type) {
222 case string:
223 switch d := dest.(type) {
224 case *string:
225 if d == nil {
226 return errNilPtr
227 }
228 *d = s
229 return nil
230 case *[]byte:
231 if d == nil {
232 return errNilPtr
233 }
234 *d = []byte(s)
235 return nil
236 case *RawBytes:
237 if d == nil {
238 return errNilPtr
239 }
240 *d = append((*d)[:0], s...)
241 return nil
242 }
243 case []byte:
244 switch d := dest.(type) {
245 case *string:
246 if d == nil {
247 return errNilPtr
248 }
249 *d = string(s)
250 return nil
251 case *any:
252 if d == nil {
253 return errNilPtr
254 }
255 *d = bytes.Clone(s)
256 return nil
257 case *[]byte:
258 if d == nil {
259 return errNilPtr
260 }
261 *d = bytes.Clone(s)
262 return nil
263 case *RawBytes:
264 if d == nil {
265 return errNilPtr
266 }
267 *d = s
268 return nil
269 }
270 case time.Time:
271 switch d := dest.(type) {
272 case *time.Time:
273 *d = s
274 return nil
275 case *string:
276 *d = s.Format(time.RFC3339Nano)
277 return nil
278 case *[]byte:
279 if d == nil {
280 return errNilPtr
281 }
282 *d = []byte(s.Format(time.RFC3339Nano))
283 return nil
284 case *RawBytes:
285 if d == nil {
286 return errNilPtr
287 }
288 *d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
289 return nil
290 }
291 case decimalDecompose:
292 switch d := dest.(type) {
293 case decimalCompose:
294 return d.Compose(s.Decompose(nil))
295 }
296 case nil:
297 switch d := dest.(type) {
298 case *any:
299 if d == nil {
300 return errNilPtr
301 }
302 *d = nil
303 return nil
304 case *[]byte:
305 if d == nil {
306 return errNilPtr
307 }
308 *d = nil
309 return nil
310 case *RawBytes:
311 if d == nil {
312 return errNilPtr
313 }
314 *d = nil
315 return nil
316 }
317
318 case driver.Rows:
319 switch d := dest.(type) {
320 case *Rows:
321 if d == nil {
322 return errNilPtr
323 }
324 if rows == nil {
325 return errors.New("invalid context to convert cursor rows, missing parent *Rows")
326 }
327 rows.closemu.Lock()
328 *d = Rows{
329 dc: rows.dc,
330 releaseConn: func(error) {},
331 rowsi: s,
332 }
333
334 parentCancel := rows.cancel
335 rows.cancel = func() {
336
337
338 d.close(rows.lasterr)
339 if parentCancel != nil {
340 parentCancel()
341 }
342 }
343 rows.closemu.Unlock()
344 return nil
345 }
346 }
347
348 var sv reflect.Value
349
350 switch d := dest.(type) {
351 case *string:
352 sv = reflect.ValueOf(src)
353 switch sv.Kind() {
354 case reflect.Bool,
355 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
356 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
357 reflect.Float32, reflect.Float64:
358 *d = asString(src)
359 return nil
360 }
361 case *[]byte:
362 sv = reflect.ValueOf(src)
363 if b, ok := asBytes(nil, sv); ok {
364 *d = b
365 return nil
366 }
367 case *RawBytes:
368 sv = reflect.ValueOf(src)
369 if b, ok := asBytes([]byte(*d)[:0], sv); ok {
370 *d = RawBytes(b)
371 return nil
372 }
373 case *bool:
374 bv, err := driver.Bool.ConvertValue(src)
375 if err == nil {
376 *d = bv.(bool)
377 }
378 return err
379 case *any:
380 *d = src
381 return nil
382 }
383
384 if scanner, ok := dest.(Scanner); ok {
385 return scanner.Scan(src)
386 }
387
388 dpv := reflect.ValueOf(dest)
389 if dpv.Kind() != reflect.Pointer {
390 return errors.New("destination not a pointer")
391 }
392 if dpv.IsNil() {
393 return errNilPtr
394 }
395
396 if !sv.IsValid() {
397 sv = reflect.ValueOf(src)
398 }
399
400 dv := reflect.Indirect(dpv)
401 if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
402 switch b := src.(type) {
403 case []byte:
404 dv.Set(reflect.ValueOf(bytes.Clone(b)))
405 default:
406 dv.Set(sv)
407 }
408 return nil
409 }
410
411 if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
412 dv.Set(sv.Convert(dv.Type()))
413 return nil
414 }
415
416
417
418
419
420
421 switch dv.Kind() {
422 case reflect.Pointer:
423 if src == nil {
424 dv.SetZero()
425 return nil
426 }
427 dv.Set(reflect.New(dv.Type().Elem()))
428 return convertAssignRows(dv.Interface(), src, rows)
429 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
430 if src == nil {
431 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
432 }
433 s := asString(src)
434 i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
435 if err != nil {
436 err = strconvErr(err)
437 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
438 }
439 dv.SetInt(i64)
440 return nil
441 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
442 if src == nil {
443 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
444 }
445 s := asString(src)
446 u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
447 if err != nil {
448 err = strconvErr(err)
449 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
450 }
451 dv.SetUint(u64)
452 return nil
453 case reflect.Float32, reflect.Float64:
454 if src == nil {
455 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
456 }
457 s := asString(src)
458 f64, err := strconv.ParseFloat(s, dv.Type().Bits())
459 if err != nil {
460 err = strconvErr(err)
461 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
462 }
463 dv.SetFloat(f64)
464 return nil
465 case reflect.String:
466 if src == nil {
467 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
468 }
469 switch v := src.(type) {
470 case string:
471 dv.SetString(v)
472 return nil
473 case []byte:
474 dv.SetString(string(v))
475 return nil
476 }
477 }
478
479 return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
480 }
481
// strconvErr unwraps a *strconv.NumError to its underlying cause (for
// example strconv.ErrSyntax or strconv.ErrRange), so messages built around
// it do not repeat the input and function name. Any other error, including
// nil, is returned unchanged.
func strconvErr(err error) error {
	switch numErr := err.(type) {
	case *strconv.NumError:
		return numErr.Err
	default:
		return err
	}
}
488
// asString formats src as a string. string and []byte are returned
// directly. Integer, unsigned, float and bool kinds (including named types
// with those kinds) use strconv; floats use the shortest 'g' form at their
// own bit size. Anything else falls back to fmt's %v.
func asString(src any) string {
	if str, ok := src.(string); ok {
		return str
	}
	if raw, ok := src.([]byte); ok {
		return string(raw)
	}

	val := reflect.ValueOf(src)
	switch kind := val.Kind(); kind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(val.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(val.Uint(), 10)
	case reflect.Float32, reflect.Float64:
		bitSize := 64
		if kind == reflect.Float32 {
			bitSize = 32
		}
		return strconv.FormatFloat(val.Float(), 'g', -1, bitSize)
	case reflect.Bool:
		return strconv.FormatBool(val.Bool())
	default:
		return fmt.Sprintf("%v", src)
	}
}
511
// asBytes appends the textual form of rv to buf and reports whether rv's
// kind is supported. Integer, unsigned, float (shortest 'g' at the value's
// bit size), bool and string kinds are supported. For any other kind it
// returns nil, false and leaves buf untouched.
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
	switch kind := rv.Kind(); kind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		b = strconv.AppendInt(buf, rv.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		b = strconv.AppendUint(buf, rv.Uint(), 10)
	case reflect.Float32, reflect.Float64:
		bitSize := 64
		if kind == reflect.Float32 {
			bitSize = 32
		}
		b = strconv.AppendFloat(buf, rv.Float(), 'g', -1, bitSize)
	case reflect.Bool:
		b = strconv.AppendBool(buf, rv.Bool())
	case reflect.String:
		b = append(buf, rv.String()...)
	default:
		return nil, false
	}
	return b, true
}
530
531 var valuerReflectType = reflect.TypeFor[driver.Valuer]()
532
533
534
535
536
537
538
539
540
541
542
543
544 func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
545 if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer &&
546 rv.IsNil() &&
547 rv.Type().Elem().Implements(valuerReflectType) {
548 return nil, nil
549 }
550 return vr.Value()
551 }
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
// decimalDecompose is implemented by source values that can split
// themselves into decimal parts. When the destination implements
// decimalCompose, convertAssignRows passes those parts to it directly.
type decimalDecompose interface {
	// Decompose returns the value's parts.
	// NOTE(review): the encoding of form, coefficient and exponent, and how
	// buf is used (presumably as scratch storage for coefficient), are not
	// defined in this file; confirm against implementing types.
	Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
}

// decimalCompose is implemented by destinations that can be set from
// decimal parts in the same shape decimalDecompose produces.
type decimalCompose interface {
	// Compose sets the receiver from the given parts and reports any
	// failure as an error.
	Compose(form byte, negative bool, coefficient []byte, exponent int32) error
}

// decimal is a value that can both decompose and compose itself.
type decimal interface {
	decimalDecompose
	decimalCompose
}
591
View as plain text