...
1
2
3
4
5 package ssh
6
7
8
9 import (
10 "crypto/rand"
11 "fmt"
12 "reflect"
13 "sync"
14 "testing"
15 )
16
17
18
19
20
21 func TestKexes(t *testing.T) {
22 type kexResultErr struct {
23 result *kexResult
24 err error
25 }
26
27 for name, kex := range kexAlgoMap {
28 t.Run(name, func(t *testing.T) {
29 wg := sync.WaitGroup{}
30 for i := 0; i < 3; i++ {
31 wg.Add(1)
32 go func() {
33 defer wg.Done()
34 a, b := memPipe()
35
36 s := make(chan kexResultErr, 1)
37 c := make(chan kexResultErr, 1)
38 var magics handshakeMagics
39 go func() {
40 r, e := kex.Client(a, rand.Reader, &magics)
41 a.Close()
42 c <- kexResultErr{r, e}
43 }()
44 go func() {
45 r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type())
46 b.Close()
47 s <- kexResultErr{r, e}
48 }()
49
50 clientRes := <-c
51 serverRes := <-s
52 if clientRes.err != nil {
53 t.Errorf("client: %v", clientRes.err)
54 }
55 if serverRes.err != nil {
56 t.Errorf("server: %v", serverRes.err)
57 }
58 if !reflect.DeepEqual(clientRes.result, serverRes.result) {
59 t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
60 }
61 }()
62 }
63 wg.Wait()
64 })
65 }
66 }
67
68 func BenchmarkKexes(b *testing.B) {
69 type kexResultErr struct {
70 result *kexResult
71 err error
72 }
73
74 for name, kex := range kexAlgoMap {
75 b.Run(name, func(b *testing.B) {
76 for i := 0; i < b.N; i++ {
77 t1, t2 := memPipe()
78
79 s := make(chan kexResultErr, 1)
80 c := make(chan kexResultErr, 1)
81 var magics handshakeMagics
82
83 go func() {
84 r, e := kex.Client(t1, rand.Reader, &magics)
85 t1.Close()
86 c <- kexResultErr{r, e}
87 }()
88 go func() {
89 r, e := kex.Server(t2, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type())
90 t2.Close()
91 s <- kexResultErr{r, e}
92 }()
93
94 clientRes := <-c
95 serverRes := <-s
96
97 if clientRes.err != nil {
98 panic(fmt.Sprintf("client: %v", clientRes.err))
99 }
100 if serverRes.err != nil {
101 panic(fmt.Sprintf("server: %v", serverRes.err))
102 }
103 }
104 })
105 }
106 }
107
View as plain text