...

Source file src/golang.org/x/crypto/ssh/test/test_unix_test.go

Documentation: golang.org/x/crypto/ssh/test

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || plan9 || solaris
     6  
     7  package test
     8  
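// Functional test harness that runs the Go SSH client against a real sshd
// process on Unix-like systems (and Plan 9, per the build constraint).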
    10  
    11  import (
    12  	"bytes"
    13  	"crypto/rand"
    14  	"encoding/base64"
    15  	"fmt"
    16  	"log"
    17  	"net"
    18  	"os"
    19  	"os/exec"
    20  	"os/user"
    21  	"path/filepath"
    22  	"testing"
    23  	"text/template"
    24  
    25  	"golang.org/x/crypto/internal/testenv"
    26  	"golang.org/x/crypto/ssh"
    27  	"golang.org/x/crypto/ssh/testdata"
    28  )
    29  
const (
	// defaultSshdConfig is the text/template source of the sshd_config used
	// by the tests. {{.Dir}} is replaced with the server's temporary
	// directory, which holds the host keys, host certificate, banner, pid
	// file and authorized_keys referenced below.
	//
	// NOTE(review): several directives here (Protocol, KeyRegenerationInterval,
	// ServerKeyBits, RSAAuthentication, RhostsRSAAuthentication) are SSH
	// protocol 1 era options, and the id_dsa host key uses DSA; newer OpenSSH
	// releases may warn about or reject these — confirm against the sshd
	// version under test.
	defaultSshdConfig = `
Protocol 2
Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
HostCertificate {{.Dir}}/id_rsa-sha2-512-cert.pub
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile	{{.Dir}}/authorized_keys
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
PubkeyAcceptedKeyTypes=*
`
	// multiAuthSshdConfigTail is appended to defaultSshdConfig for the
	// "MultiAuth" configuration. It enables PAM, password and
	// challenge-response authentication, and expects the caller to supply
	// {{.AuthMethods}} through the template variables.
	multiAuthSshdConfigTail = `
UsePAM yes
PasswordAuthentication yes
ChallengeResponseAuthentication yes
AuthenticationMethods {{.AuthMethods}}
`
)
    63  
    64  var configTmpl = map[string]*template.Template{
    65  	"default":   template.Must(template.New("").Parse(defaultSshdConfig)),
    66  	"MultiAuth": template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail))}
    67  
    68  type server struct {
    69  	t          *testing.T
    70  	configfile string
    71  
    72  	testUser     string // test username for sshd
    73  	testPasswd   string // test password for sshd
    74  	sshdTestPwSo string // dynamic library to inject a custom password into sshd
    75  
    76  	lastDialConn net.Conn
    77  }
    78  
    79  func username() string {
    80  	var username string
    81  	if user, err := user.Current(); err == nil {
    82  		username = user.Username
    83  	} else {
    84  		// user.Current() currently requires cgo. If an error is
    85  		// returned attempt to get the username from the environment.
    86  		log.Printf("user.Current: %v; falling back on $USER", err)
    87  		username = os.Getenv("USER")
    88  	}
    89  	if username == "" {
    90  		panic("Unable to get username")
    91  	}
    92  	return username
    93  }
    94  
    95  type storedHostKey struct {
    96  	// keys map from an algorithm string to binary key data.
    97  	keys map[string][]byte
    98  
    99  	// checkCount counts the Check calls. Used for testing
   100  	// rekeying.
   101  	checkCount int
   102  }
   103  
   104  func (k *storedHostKey) Add(key ssh.PublicKey) {
   105  	if k.keys == nil {
   106  		k.keys = map[string][]byte{}
   107  	}
   108  	k.keys[key.Type()] = key.Marshal()
   109  }
   110  
   111  func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
   112  	k.checkCount++
   113  	algo := key.Type()
   114  
   115  	if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
   116  		return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
   117  	}
   118  	return nil
   119  }
   120  
   121  func hostKeyDB() *storedHostKey {
   122  	keyChecker := &storedHostKey{}
   123  	keyChecker.Add(testPublicKeys["ecdsa"])
   124  	keyChecker.Add(testPublicKeys["rsa"])
   125  	keyChecker.Add(testPublicKeys["dsa"])
   126  	return keyChecker
   127  }
   128  
   129  func clientConfig() *ssh.ClientConfig {
   130  	config := &ssh.ClientConfig{
   131  		User: username(),
   132  		Auth: []ssh.AuthMethod{
   133  			ssh.PublicKeys(testSigners["user"]),
   134  		},
   135  		HostKeyCallback: hostKeyDB().Check,
   136  		HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker
   137  			ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
   138  			ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
   139  			ssh.KeyAlgoED25519,
   140  		},
   141  	}
   142  	return config
   143  }
   144  
   145  // unixConnection creates two halves of a connected net.UnixConn.  It
   146  // is used for connecting the Go SSH client with sshd without opening
   147  // ports.
   148  func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
   149  	dir, err := os.MkdirTemp("", "unixConnection")
   150  	if err != nil {
   151  		return nil, nil, err
   152  	}
   153  	defer os.Remove(dir)
   154  
   155  	addr := filepath.Join(dir, "ssh")
   156  	listener, err := net.Listen("unix", addr)
   157  	if err != nil {
   158  		return nil, nil, err
   159  	}
   160  	defer listener.Close()
   161  	c1, err := net.Dial("unix", addr)
   162  	if err != nil {
   163  		return nil, nil, err
   164  	}
   165  
   166  	c2, err := listener.Accept()
   167  	if err != nil {
   168  		c1.Close()
   169  		return nil, nil, err
   170  	}
   171  
   172  	return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
   173  }
   174  
   175  func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
   176  	return s.TryDialWithAddr(config, "")
   177  }
   178  
// TryDialWithAddr starts a new sshd process and connects the Go SSH client
// to it over a private Unix socket pair instead of a TCP port.
//
// sshd is run as "sshd -f s.configfile -i -e" with its stdin and stdout
// attached to one end of the socket pair, and ssh.NewClientConn performs
// the handshake with config on the other end. If s.sshdTestPwSo is set,
// sshd is started with that library in LD_PRELOAD and with TEST_USER and
// TEST_PASSWD taken from s.testUser and s.testPasswd.
//
// addr is the user specified host:port. While we don't actually dial it,
// we need to know this for host key matching.
//
// The test is skipped if sshd cannot be found via exec.LookPath, and is
// failed through s.t if the socket pair, the descriptor duplication or the
// sshd process cannot be set up. Handshake errors are returned to the
// caller. On success the returned client wraps the client-side connection
// and is responsible for closing it. The sshd process is interrupted and
// waited for in a test cleanup, which also logs its stderr when the test
// failed or ran with -v.
func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (client *ssh.Client, err error) {
	sshd, err := exec.LookPath("sshd")
	if err != nil {
		s.t.Skipf("skipping test: %v", err)
	}

	c1, c2, err := unixConnection()
	if err != nil {
		s.t.Fatalf("unixConnection: %v", err)
	}
	defer func() {
		// Close c2 after we've started the sshd command so that it won't prevent c1
		// from returning EOF when the sshd command exits.
		c2.Close()

		// Leave c1 open if we're returning a client that wraps it.
		// (The client is responsible for closing it.)
		// Otherwise, close it to free up the socket.
		// This reads the named result, so it sees the value actually returned.
		if client == nil {
			c1.Close()
		}
	}()

	// f is a separate *os.File for c2's socket; it becomes sshd's stdin and
	// stdout. The deferred f.Close runs before the deferred function above
	// (defers run last-in, first-out), releasing the parent's copy.
	f, err := c2.File()
	if err != nil {
		s.t.Fatalf("UnixConn.File: %v", err)
	}
	defer f.Close()

	cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
	cmd.Stdin = f
	cmd.Stdout = f
	// sshd's stderr is buffered so the cleanup below can log it.
	cmd.Stderr = new(bytes.Buffer)

	// Optionally preload the password-injection wrapper; it reads the test
	// credentials from the TEST_USER and TEST_PASSWD environment variables.
	if s.sshdTestPwSo != "" {
		if s.testUser == "" {
			s.t.Fatal("user missing from sshd_test_pw.so config")
		}
		if s.testPasswd == "" {
			s.t.Fatal("password missing from sshd_test_pw.so config")
		}
		cmd.Env = append(os.Environ(),
			fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
			fmt.Sprintf("TEST_USER=%s", s.testUser),
			fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
	}

	if err := cmd.Start(); err != nil {
		s.t.Fatalf("s.cmd.Start: %v", err)
	}
	// Remember the client end of the socket pair.
	// NOTE(review): if the handshake below fails, the deferred function above
	// closes c1, so lastDialConn then refers to a closed connection.
	s.lastDialConn = c1
	s.t.Cleanup(func() {
		// Don't check for errors; if it fails it's most
		// likely "os: process already finished", and we don't
		// care about that. Use os.Interrupt, so child
		// processes are killed too.
		cmd.Process.Signal(os.Interrupt)
		cmd.Wait()
		if s.t.Failed() || testing.Verbose() {
			// log any output from sshd process
			s.t.Logf("sshd:\n%s", cmd.Stderr)
		}
	})

	conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
	if err != nil {
		return nil, err
	}
	return ssh.NewClient(conn, chans, reqs), nil
}
   251  
   252  func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
   253  	conn, err := s.TryDial(config)
   254  	if err != nil {
   255  		s.t.Fatalf("ssh.Client: %v", err)
   256  	}
   257  	return conn
   258  }
   259  
   260  func writeFile(path string, contents []byte) {
   261  	f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
   262  	if err != nil {
   263  		panic(err)
   264  	}
   265  	defer f.Close()
   266  	if _, err := f.Write(contents); err != nil {
   267  		panic(err)
   268  	}
   269  }
   270  
   271  // generate random password
   272  func randomPassword() (string, error) {
   273  	b := make([]byte, 12)
   274  	_, err := rand.Read(b)
   275  	if err != nil {
   276  		return "", err
   277  	}
   278  	return base64.RawURLEncoding.EncodeToString(b), nil
   279  }
   280  
   281  // setTestPassword is used for setting user and password data for sshd_test_pw.so
   282  // This function also checks that ./sshd_test_pw.so exists and if not calls s.t.Skip()
   283  func (s *server) setTestPassword(user, passwd string) error {
   284  	wd, _ := os.Getwd()
   285  	wrapper := filepath.Join(wd, "sshd_test_pw.so")
   286  	if _, err := os.Stat(wrapper); err != nil {
   287  		s.t.Skip(fmt.Errorf("sshd_test_pw.so is not available"))
   288  		return err
   289  	}
   290  
   291  	s.sshdTestPwSo = wrapper
   292  	s.testUser = user
   293  	s.testPasswd = passwd
   294  	return nil
   295  }
   296  
   297  // newServer returns a new mock ssh server.
   298  func newServer(t *testing.T) *server {
   299  	return newServerForConfig(t, "default", map[string]string{})
   300  }
   301  
   302  // newServerForConfig returns a new mock ssh server.
   303  func newServerForConfig(t *testing.T, config string, configVars map[string]string) *server {
   304  	if testing.Short() {
   305  		t.Skip("skipping test due to -short")
   306  	}
   307  	u, err := user.Current()
   308  	if err != nil {
   309  		t.Fatalf("user.Current: %v", err)
   310  	}
   311  	uname := u.Name
   312  	if uname == "" {
   313  		// Check the value of u.Username as u.Name
   314  		// can be "" on some OSes like AIX.
   315  		uname = u.Username
   316  	}
   317  	if uname == "root" {
   318  		t.Skip("skipping test because current user is root")
   319  	}
   320  	dir, err := os.MkdirTemp("", "sshtest")
   321  	if err != nil {
   322  		t.Fatal(err)
   323  	}
   324  	f, err := os.Create(filepath.Join(dir, "sshd_config"))
   325  	if err != nil {
   326  		t.Fatal(err)
   327  	}
   328  	if _, ok := configTmpl[config]; ok == false {
   329  		t.Fatal(fmt.Errorf("Invalid server config '%s'", config))
   330  	}
   331  	configVars["Dir"] = dir
   332  	err = configTmpl[config].Execute(f, configVars)
   333  	if err != nil {
   334  		t.Fatal(err)
   335  	}
   336  	f.Close()
   337  
   338  	writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
   339  
   340  	for k, v := range testdata.PEMBytes {
   341  		filename := "id_" + k
   342  		writeFile(filepath.Join(dir, filename), v)
   343  		writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
   344  	}
   345  
   346  	for k, v := range testdata.SSHCertificates {
   347  		filename := "id_" + k + "-cert.pub"
   348  		writeFile(filepath.Join(dir, filename), v)
   349  	}
   350  
   351  	var authkeys bytes.Buffer
   352  	for k := range testdata.PEMBytes {
   353  		authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
   354  	}
   355  	writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
   356  	t.Cleanup(func() {
   357  		if err := os.RemoveAll(dir); err != nil {
   358  			t.Error(err)
   359  		}
   360  	})
   361  
   362  	return &server{
   363  		t:          t,
   364  		configfile: f.Name(),
   365  	}
   366  }
   367  
   368  func newTempSocket(t *testing.T) (string, func()) {
   369  	dir, err := os.MkdirTemp("", "socket")
   370  	if err != nil {
   371  		t.Fatal(err)
   372  	}
   373  	deferFunc := func() { os.RemoveAll(dir) }
   374  	addr := filepath.Join(dir, "sock")
   375  	return addr, deferFunc
   376  }
   377  

View as plain text