...

Source file src/github.com/go-mail/mail/smtp_test.go

Documentation: github.com/go-mail/mail

     1  package mail
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/tls"
     6  	"io"
     7  	"net"
     8  	"net/smtp"
     9  	"reflect"
    10  	"testing"
    11  	"time"
    12  )
    13  
    14  const (
    15  	testPort    = 587
    16  	testSSLPort = 465
    17  )
    18  
    19  var (
    20  	testConn    = &net.TCPConn{}
    21  	testTLSConn = tls.Client(testConn, &tls.Config{InsecureSkipVerify: true})
    22  	testConfig  = &tls.Config{InsecureSkipVerify: true}
    23  	testAuth    = smtp.PlainAuth("", testUser, testPwd, testHost)
    24  )
    25  
    26  func TestDialer(t *testing.T) {
    27  	d := NewDialer(testHost, testPort, "user", "pwd")
    28  	testSendMail(t, d, []string{
    29  		"Extension STARTTLS",
    30  		"StartTLS",
    31  		"Extension AUTH",
    32  		"Auth",
    33  		"Mail " + testFrom,
    34  		"Rcpt " + testTo1,
    35  		"Rcpt " + testTo2,
    36  		"Data",
    37  		"Write message",
    38  		"Close writer",
    39  		"Quit",
    40  		"Close",
    41  	})
    42  }
    43  
    44  func TestDialerSSL(t *testing.T) {
    45  	d := NewDialer(testHost, testSSLPort, "user", "pwd")
    46  	testSendMail(t, d, []string{
    47  		"Extension AUTH",
    48  		"Auth",
    49  		"Mail " + testFrom,
    50  		"Rcpt " + testTo1,
    51  		"Rcpt " + testTo2,
    52  		"Data",
    53  		"Write message",
    54  		"Close writer",
    55  		"Quit",
    56  		"Close",
    57  	})
    58  }
    59  
    60  func TestDialerConfig(t *testing.T) {
    61  	d := NewDialer(testHost, testPort, "user", "pwd")
    62  	d.LocalName = "test"
    63  	d.TLSConfig = testConfig
    64  	testSendMail(t, d, []string{
    65  		"Hello test",
    66  		"Extension STARTTLS",
    67  		"StartTLS",
    68  		"Extension AUTH",
    69  		"Auth",
    70  		"Mail " + testFrom,
    71  		"Rcpt " + testTo1,
    72  		"Rcpt " + testTo2,
    73  		"Data",
    74  		"Write message",
    75  		"Close writer",
    76  		"Quit",
    77  		"Close",
    78  	})
    79  }
    80  
    81  func TestDialerSSLConfig(t *testing.T) {
    82  	d := NewDialer(testHost, testSSLPort, "user", "pwd")
    83  	d.LocalName = "test"
    84  	d.TLSConfig = testConfig
    85  	testSendMail(t, d, []string{
    86  		"Hello test",
    87  		"Extension AUTH",
    88  		"Auth",
    89  		"Mail " + testFrom,
    90  		"Rcpt " + testTo1,
    91  		"Rcpt " + testTo2,
    92  		"Data",
    93  		"Write message",
    94  		"Close writer",
    95  		"Quit",
    96  		"Close",
    97  	})
    98  }
    99  
   100  func TestDialerNoStartTLS(t *testing.T) {
   101  	d := NewDialer(testHost, testPort, "user", "pwd")
   102  	d.StartTLSPolicy = NoStartTLS
   103  	testSendMail(t, d, []string{
   104  		"Extension AUTH",
   105  		"Auth",
   106  		"Mail " + testFrom,
   107  		"Rcpt " + testTo1,
   108  		"Rcpt " + testTo2,
   109  		"Data",
   110  		"Write message",
   111  		"Close writer",
   112  		"Quit",
   113  		"Close",
   114  	})
   115  }
   116  
   117  func TestDialerOpportunisticStartTLS(t *testing.T) {
   118  	d := NewDialer(testHost, testPort, "user", "pwd")
   119  	d.StartTLSPolicy = OpportunisticStartTLS
   120  	testSendMail(t, d, []string{
   121  		"Extension STARTTLS",
   122  		"StartTLS",
   123  		"Extension AUTH",
   124  		"Auth",
   125  		"Mail " + testFrom,
   126  		"Rcpt " + testTo1,
   127  		"Rcpt " + testTo2,
   128  		"Data",
   129  		"Write message",
   130  		"Close writer",
   131  		"Quit",
   132  		"Close",
   133  	})
   134  
   135  	if OpportunisticStartTLS != 0 {
   136  		t.Errorf("OpportunisticStartTLS: expected 0, got %d",
   137  			OpportunisticStartTLS)
   138  	}
   139  }
   140  
   141  func TestDialerOpportunisticStartTLSUnsupported(t *testing.T) {
   142  	d := NewDialer(testHost, testPort, "user", "pwd")
   143  	d.StartTLSPolicy = OpportunisticStartTLS
   144  	testSendMailStartTLSUnsupported(t, d, []string{
   145  		"Extension STARTTLS",
   146  		"Extension AUTH",
   147  		"Auth",
   148  		"Mail " + testFrom,
   149  		"Rcpt " + testTo1,
   150  		"Rcpt " + testTo2,
   151  		"Data",
   152  		"Write message",
   153  		"Close writer",
   154  		"Quit",
   155  		"Close",
   156  	})
   157  }
   158  
   159  func TestDialerMandatoryStartTLS(t *testing.T) {
   160  	d := NewDialer(testHost, testPort, "user", "pwd")
   161  	d.StartTLSPolicy = MandatoryStartTLS
   162  	testSendMail(t, d, []string{
   163  		"Extension STARTTLS",
   164  		"StartTLS",
   165  		"Extension AUTH",
   166  		"Auth",
   167  		"Mail " + testFrom,
   168  		"Rcpt " + testTo1,
   169  		"Rcpt " + testTo2,
   170  		"Data",
   171  		"Write message",
   172  		"Close writer",
   173  		"Quit",
   174  		"Close",
   175  	})
   176  }
   177  
   178  func TestDialerMandatoryStartTLSUnsupported(t *testing.T) {
   179  	d := NewDialer(testHost, testPort, "user", "pwd")
   180  	d.StartTLSPolicy = MandatoryStartTLS
   181  
   182  	testClient := &mockClient{
   183  		t:        t,
   184  		addr:     addr(d.Host, d.Port),
   185  		config:   d.TLSConfig,
   186  		startTLS: false,
   187  		timeout:  true,
   188  	}
   189  
   190  	err := doTestSendMail(t, d, testClient, []string{
   191  		"Extension STARTTLS",
   192  	})
   193  
   194  	if _, ok := err.(StartTLSUnsupportedError); !ok {
   195  		t.Errorf("expected StartTLSUnsupportedError, but got: %s",
   196  			reflect.TypeOf(err).Name())
   197  	}
   198  
   199  	expected := "gomail: MandatoryStartTLS required, " +
   200  		"but SMTP server does not support STARTTLS"
   201  	if err.Error() != expected {
   202  		t.Errorf("expected %s, but got: %s", expected, err)
   203  	}
   204  }
   205  
   206  func TestDialerNoAuth(t *testing.T) {
   207  	d := &Dialer{
   208  		Host: testHost,
   209  		Port: testPort,
   210  	}
   211  	testSendMail(t, d, []string{
   212  		"Extension STARTTLS",
   213  		"StartTLS",
   214  		"Mail " + testFrom,
   215  		"Rcpt " + testTo1,
   216  		"Rcpt " + testTo2,
   217  		"Data",
   218  		"Write message",
   219  		"Close writer",
   220  		"Quit",
   221  		"Close",
   222  	})
   223  }
   224  
   225  func TestDialerTimeout(t *testing.T) {
   226  	d := &Dialer{
   227  		Host:         testHost,
   228  		Port:         testPort,
   229  		RetryFailure: true,
   230  	}
   231  	testSendMailTimeout(t, d, []string{
   232  		"Extension STARTTLS",
   233  		"StartTLS",
   234  		"Mail " + testFrom,
   235  		"Extension STARTTLS",
   236  		"StartTLS",
   237  		"Mail " + testFrom,
   238  		"Rcpt " + testTo1,
   239  		"Rcpt " + testTo2,
   240  		"Data",
   241  		"Write message",
   242  		"Close writer",
   243  		"Quit",
   244  		"Close",
   245  	})
   246  }
   247  
   248  func TestDialerTimeoutNoRetry(t *testing.T) {
   249  	d := &Dialer{
   250  		Host:         testHost,
   251  		Port:         testPort,
   252  		RetryFailure: false,
   253  	}
   254  	testClient := &mockClient{
   255  		t:        t,
   256  		addr:     addr(d.Host, d.Port),
   257  		config:   d.TLSConfig,
   258  		startTLS: true,
   259  		timeout:  true,
   260  	}
   261  
   262  	err := doTestSendMail(t, d, testClient, []string{
   263  		"Extension STARTTLS",
   264  		"StartTLS",
   265  		"Mail " + testFrom,
   266  		"Quit",
   267  	})
   268  
   269  	if err.Error() != "gomail: could not send email 1: EOF" {
   270  		t.Error("expected to have got EOF, but got:", err)
   271  	}
   272  }
   273  
   274  type mockClient struct {
   275  	t        *testing.T
   276  	i        int
   277  	want     []string
   278  	addr     string
   279  	config   *tls.Config
   280  	startTLS bool
   281  	timeout  bool
   282  }
   283  
   284  func (c *mockClient) Hello(localName string) error {
   285  	c.do("Hello " + localName)
   286  	return nil
   287  }
   288  
   289  func (c *mockClient) Extension(ext string) (bool, string) {
   290  	c.do("Extension " + ext)
   291  	ok := true
   292  	if ext == "STARTTLS" {
   293  		ok = c.startTLS
   294  	}
   295  	return ok, ""
   296  }
   297  
   298  func (c *mockClient) StartTLS(config *tls.Config) error {
   299  	assertConfig(c.t, config, c.config)
   300  	c.do("StartTLS")
   301  	return nil
   302  }
   303  
   304  func (c *mockClient) Auth(a smtp.Auth) error {
   305  	if !reflect.DeepEqual(a, testAuth) {
   306  		c.t.Errorf("Invalid auth, got %#v, want %#v", a, testAuth)
   307  	}
   308  	c.do("Auth")
   309  	return nil
   310  }
   311  
   312  func (c *mockClient) Mail(from string) error {
   313  	c.do("Mail " + from)
   314  	if c.timeout {
   315  		c.timeout = false
   316  		return io.EOF
   317  	}
   318  	return nil
   319  }
   320  
   321  func (c *mockClient) Rcpt(to string) error {
   322  	c.do("Rcpt " + to)
   323  	return nil
   324  }
   325  
   326  func (c *mockClient) Data() (io.WriteCloser, error) {
   327  	c.do("Data")
   328  	return &mockWriter{c: c, want: testMsg}, nil
   329  }
   330  
   331  func (c *mockClient) Quit() error {
   332  	c.do("Quit")
   333  	return nil
   334  }
   335  
   336  func (c *mockClient) Close() error {
   337  	c.do("Close")
   338  	return nil
   339  }
   340  
   341  func (c *mockClient) do(cmd string) {
   342  	if c.i >= len(c.want) {
   343  		c.t.Fatalf("Invalid command %q", cmd)
   344  	}
   345  
   346  	if cmd != c.want[c.i] {
   347  		c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i])
   348  	}
   349  	c.i++
   350  }
   351  
   352  type mockWriter struct {
   353  	want string
   354  	c    *mockClient
   355  	buf  bytes.Buffer
   356  }
   357  
   358  func (w *mockWriter) Write(p []byte) (int, error) {
   359  	if w.buf.Len() == 0 {
   360  		w.c.do("Write message")
   361  	}
   362  	w.buf.Write(p)
   363  	return len(p), nil
   364  }
   365  
   366  func (w *mockWriter) Close() error {
   367  	compareBodies(w.c.t, w.buf.String(), w.want)
   368  	w.c.do("Close writer")
   369  	return nil
   370  }
   371  
   372  func testSendMail(t *testing.T, d *Dialer, want []string) {
   373  	testClient := &mockClient{
   374  		t:        t,
   375  		addr:     addr(d.Host, d.Port),
   376  		config:   d.TLSConfig,
   377  		startTLS: true,
   378  		timeout:  false,
   379  	}
   380  
   381  	if err := doTestSendMail(t, d, testClient, want); err != nil {
   382  		t.Error(err)
   383  	}
   384  }
   385  
   386  func testSendMailStartTLSUnsupported(t *testing.T, d *Dialer, want []string) {
   387  	testClient := &mockClient{
   388  		t:        t,
   389  		addr:     addr(d.Host, d.Port),
   390  		config:   d.TLSConfig,
   391  		startTLS: false,
   392  		timeout:  false,
   393  	}
   394  
   395  	if err := doTestSendMail(t, d, testClient, want); err != nil {
   396  		t.Error(err)
   397  	}
   398  }
   399  
   400  func testSendMailTimeout(t *testing.T, d *Dialer, want []string) {
   401  	testClient := &mockClient{
   402  		t:        t,
   403  		addr:     addr(d.Host, d.Port),
   404  		config:   d.TLSConfig,
   405  		startTLS: true,
   406  		timeout:  true,
   407  	}
   408  
   409  	if err := doTestSendMail(t, d, testClient, want); err != nil {
   410  		t.Error(err)
   411  	}
   412  }
   413  
   414  func doTestSendMail(t *testing.T, d *Dialer, testClient *mockClient, want []string) error {
   415  	testClient.want = want
   416  
   417  	NetDialTimeout = func(network, address string, d time.Duration) (net.Conn, error) {
   418  		if network != "tcp" {
   419  			t.Errorf("Invalid network, got %q, want tcp", network)
   420  		}
   421  		if address != testClient.addr {
   422  			t.Errorf("Invalid address, got %q, want %q",
   423  				address, testClient.addr)
   424  		}
   425  		return testConn, nil
   426  	}
   427  
   428  	tlsClient = func(conn net.Conn, config *tls.Config) *tls.Conn {
   429  		if conn != testConn {
   430  			t.Errorf("Invalid conn, got %#v, want %#v", conn, testConn)
   431  		}
   432  		assertConfig(t, config, testClient.config)
   433  		return testTLSConn
   434  	}
   435  
   436  	smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) {
   437  		if host != testHost {
   438  			t.Errorf("Invalid host, got %q, want %q", host, testHost)
   439  		}
   440  		return testClient, nil
   441  	}
   442  
   443  	return d.DialAndSend(getTestMessage())
   444  }
   445  
   446  func assertConfig(t *testing.T, got, want *tls.Config) {
   447  	if want == nil {
   448  		want = &tls.Config{ServerName: testHost}
   449  	}
   450  	if got.ServerName != want.ServerName {
   451  		t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName)
   452  	}
   453  	if got.InsecureSkipVerify != want.InsecureSkipVerify {
   454  		t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify)
   455  	}
   456  }
   457  

View as plain text