commit 28a768458bf1d595e4a98549e750b292a79737b5 Author: Mahno Date: Mon Apr 14 22:09:38 2025 +0800 add files diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c14a297 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +go.sum \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..11b0683 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module smtp-filter + +go 1.23.7 + +require github.com/sirupsen/logrus v1.9.3 + +require ( + github.com/fatedier/golib v0.5.1 + golang.org/x/sys v0.19.0 // indirect +) diff --git a/main.go b/main.go new file mode 100644 index 0000000..f1328c4 --- /dev/null +++ b/main.go @@ -0,0 +1,242 @@ +package main + +import ( + "bufio" + "crypto/tls" + "io" + "net" + "os" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + listenAddr = ":2525" + targetAddr = "localhost:25" + allowDomain = "libertarian.dev" + timeout = 30 * time.Second + certFile = "/home/haswell/Projects/mailu-deploy/mailu/certs/cert.pem" + keyFile = "/home/haswell/Projects/mailu-deploy/mailu/certs/key.pem" +) + +var ( + serverTLSConf *tls.Config +) + +type timeoutConn struct { + conn net.Conn +} + +func (c timeoutConn) Read(buf []byte) (int, error) { + c.conn.SetDeadline(time.Now().Add(time.Duration(30) * time.Second)) + return c.conn.Read(buf) +} +func (c timeoutConn) Write(buf []byte) (int, error) { + c.conn.SetDeadline(time.Now().Add(time.Duration(30) * time.Second)) + return c.conn.Write(buf) +} + +type smtpConnection struct { + client timeoutConn + target timeoutConn + reader *bufio.Reader + clientAddr string + isRcptValid bool +} + +func newSMTPConnection(clientConn net.Conn) (*smtpConnection, error) { + targetConn, err := net.Dial("tcp", targetAddr) + if err != nil { + return nil, err + } + + return &smtpConnection{ + client: timeoutConn{clientConn}, + target: timeoutConn{targetConn}, + reader: bufio.NewReader(clientConn), + clientAddr: clientConn.RemoteAddr().String(), + isRcptValid: false, + }, nil +} + +func (s *smtpConnection) close() { + s.client.conn.Close() + s.target.conn.Close() +} + +func (s *smtpConnection) upgradeToTLS() error { + // Upgrade client connection + if _, err := s.client.Write([]byte("220 Ready to start TLS\r\n")); err != nil { + return err + } + s.client.conn = tls.Server(s.client.conn, serverTLSConf) + + // Reset reader for new TLS connection + s.reader = bufio.NewReader(s.client) + return nil +} + +func (s *smtpConnection) readCommand() (string, error) { + line, err := s.reader.ReadString('\n') + if err != nil { + return "", err + } + return line, nil +} + +func handleConnection(clientConn net.Conn) { + smtp, err := newSMTPConnection(clientConn) + if err != nil { + log.Errorf("Failed to create SMTP connection: %v", err) + clientConn.Close() + return + } + defer smtp.close() + + log.Infof("New connection from %s", smtp.clientAddr) + + // Forward initial greeting + smtp.target.Write([]byte("")) + response := make([]byte, 1024) + n, err := smtp.target.Read(response) + if err != nil { + log.Errorf("Failed to get greeting: %v", err) + return + } + smtp.client.Write([]byte(response[:n])) + for { + if !smtp.isRcptValid { + line, err := smtp.readCommand() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + log.Warnf("Client timeout from %s: %v", smtp.clientAddr, err) + } else { + log.Errorf("Error reading from client %s: %v", smtp.clientAddr, err) + } + return + } + + cmd := strings.ToUpper(strings.TrimSpace(line)) + log.Debugf("Received command: %s", cmd) + + // Handle STARTTLS before anything else + if strings.HasPrefix(cmd, "STARTTLS") { + + if err := smtp.upgradeToTLS(); err != nil { + log.Errorf("Failed to send STARTTLS response: %v", err) + return + } + log.Infof("Connection upgraded to TLS for %s", smtp.clientAddr) + continue + } + + // Handle RCPT TO validation inside TLS if needed + if strings.HasPrefix(cmd, "RCPT TO:") { + parts := strings.Split(strings.ToLower(line), ":") + if len(parts) != 2 { + smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n")) + return + } + + email := strings.Trim(strings.TrimSpace(parts[1]), "<>") + domain := strings.Split(email, "@") + if len(domain) != 2 { + smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n")) + return + } + + if domain[1] != allowDomain { + log.Warnf("Rejected recipient from %s: domain %s not allowed", smtp.clientAddr, domain[1]) + smtp.client.Write([]byte("554 Domain not allowed\r\n")) + return + } + smtp.isRcptValid = true + } + + // Check DATA command + if strings.HasPrefix(cmd, "DATA") && !smtp.isRcptValid { + log.Warnf("Disconnecting %s - attempted DATA without valid recipient", smtp.clientAddr) + return + } + //Otherwise Forward to target and get response + smtp.target.Write([]byte(line)) + n, err := smtp.target.Read(response) + if err != nil { + log.Errorf("Failed to read from target: %v", err) // Added log output + return + } + smtp.client.Write([]byte(response[:n])) + } else { + for smtp.reader.Buffered() > 0 { + line, _ := smtp.readCommand() + smtp.target.Write([]byte(line)) + n, err := smtp.target.Read(response) + if err != nil { + log.Errorf("Failed to read from target: %v", err) // Added log output + return + } + smtp.client.Write([]byte(response[:n])) + } + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + if _, err := io.Copy(smtp.target, smtp.client); err != nil { + log.Errorf("Error forwarding client to server: %v", err) + } + }() + + // 从服务器到客户端 + go func() { + defer wg.Done() + if _, err := io.Copy(smtp.client, smtp.target); err != nil { + log.Errorf("Error forwarding server to client: %v", err) + } + }() + + wg.Wait() + log.Infof("Connection from %s closed", smtp.client.conn.RemoteAddr()) + return + } + } +} + +func main() { + // Configure logging + log.SetLevel(log.InfoLevel) + keyLogFile, _ := os.OpenFile("/tmp/tlskey.log", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + // Load certificate + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalf("Failed to load certificate: %v", err) + return + } + + // TLS config for server + serverTLSConf = &tls.Config{ + Certificates: []tls.Certificate{cert}, + KeyLogWriter: keyLogFile, + } + + // Create TCP listener + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + log.Fatalf("Failed to start server: %v", err) + } + defer listener.Close() + + log.Infof("SMTP filter listening on %s", listenAddr) + + // Accept connections + for { + conn, err := listener.Accept() + if err != nil { + log.Errorf("Failed to accept connection: %v", err) + continue + } + go handleConnection(conn) + } +}