smtp-filter/main.go
2025-04-15 18:27:54 +08:00

269 lines
6.9 KiB
Go

package main
import (
"bufio"
"crypto/tls"
"io"
"net"
"os"
"regexp"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// Static configuration of the filter. Everything here is a compile-time
// constant; there is no flag or environment handling for these values.
const (
// listenAddr is the TCP address the filter accepts client connections on.
listenAddr = ":2525"
// targetAddr is the upstream SMTP server that every accepted client
// connection is paired with.
targetAddr = "localhost:25"
// allowDomain is the only recipient domain accepted in RCPT commands; the
// comparison is made against the lower-cased command line.
allowDomain = "libertarian.dev"
// timeout is not referenced anywhere in the visible source; timeoutConn
// hard-codes the same 30-second value instead.
timeout = 30 * time.Second
// certFile and keyFile are the PEM certificate and private key that main
// loads for the STARTTLS server configuration.
certFile = "/home/haswell/Projects/mailu-deploy/mailu/certs/cert.pem"
keyFile = "/home/haswell/Projects/mailu-deploy/mailu/certs/key.pem"
)
var (
// serverTLSConf is the TLS configuration handed to tls.Server when a client
// connection is upgraded after STARTTLS. It is assigned once in main.
serverTLSConf *tls.Config
)
// re matches a whole string shaped like an e-mail address: a local part made
// of letters, digits and "._%+-", an "@", a dotted domain, and a final label
// of at least two letters. It is not used by any live code in the visible
// source.
var re = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
// timeoutConn wraps a net.Conn so that every Read and Write first pushes the
// connection's deadline 30 seconds into the future, turning the deadline into
// an idle timeout.
//
// SetDeadline (rather than SetReadDeadline/SetWriteDeadline) is used on
// purpose: activity in either direction keeps the other direction alive.
type timeoutConn struct {
	conn net.Conn
}

// Read refreshes the deadline and then reads from the wrapped connection.
// If the deadline cannot be set, that error is returned and no read is
// attempted, so I/O never silently proceeds without a timeout.
func (c timeoutConn) Read(buf []byte) (int, error) {
	if err := c.conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
		return 0, err
	}
	return c.conn.Read(buf)
}

// Write refreshes the deadline and then writes to the wrapped connection.
// If the deadline cannot be set, that error is returned and nothing is
// written.
func (c timeoutConn) Write(buf []byte) (int, error) {
	if err := c.conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
		return 0, err
	}
	return c.conn.Write(buf)
}
// smtpConnection holds the state of one proxied session: the client side, the
// upstream side, and what the filter has learned about the session so far.
type smtpConnection struct {
// client is the connection to the SMTP client; upgradeToTLS swaps its
// underlying conn for a TLS server connection.
client timeoutConn
// target is the connection to the upstream server dialed at targetAddr.
target timeoutConn
// reader buffers line-oriented reads of client commands and is recreated by
// upgradeToTLS.
reader *bufio.Reader
// clientAddr is the client's remote address, captured once for log messages.
clientAddr string
// isRcptValid is set by handleConnection once a recipient in allowDomain has
// been seen.
isRcptValid bool
}
// newSMTPConnection dials the upstream server at targetAddr and pairs that
// connection with clientConn.
//
// Both connections are wrapped in timeoutConn. The command reader is layered
// on the wrapped client (not on the raw clientConn) so that plaintext reads
// refresh the idle deadline exactly like the reader that upgradeToTLS
// installs after STARTTLS.
//
// If the dial fails the error is returned and clientConn is left open; the
// caller is responsible for closing it.
func newSMTPConnection(clientConn net.Conn) (*smtpConnection, error) {
	targetConn, err := net.Dial("tcp", targetAddr)
	if err != nil {
		return nil, err
	}
	client := timeoutConn{conn: clientConn}
	return &smtpConnection{
		client:      client,
		target:      timeoutConn{conn: targetConn},
		reader:      bufio.NewReader(client),
		clientAddr:  clientConn.RemoteAddr().String(),
		isRcptValid: false,
	}, nil
}
// close shuts down both halves of the session, client side first and then
// the upstream side. Errors from Close are ignored.
func (s *smtpConnection) close() {
	for _, side := range []timeoutConn{s.client, s.target} {
		side.conn.Close()
	}
}
// upgradeToTLS answers a STARTTLS request and switches the client side of the
// session to TLS.
//
// It writes the "220" go-ahead over the current connection, wraps that
// connection with tls.Server using serverTLSConf, and replaces s.reader with
// a fresh reader on top of the wrapped connection, so anything the previous
// reader still had buffered is dropped. No handshake is started explicitly
// here. An error is returned only when the go-ahead cannot be written.
func (s *smtpConnection) upgradeToTLS() error {
	const goAhead = "220 Ready to start TLS\r\n"
	_, err := s.client.Write([]byte(goAhead))
	if err != nil {
		return err
	}

	secured := tls.Server(s.client.conn, serverTLSConf)
	s.client.conn = secured
	// The old reader was bound to the plaintext stream; start over on TLS.
	s.reader = bufio.NewReader(s.client)
	return nil
}
// readCommand returns the next client line, including its trailing '\n'.
// When the read fails, any partial line is discarded and an empty string is
// returned together with the error.
func (s *smtpConnection) readCommand() (string, error) {
	cmd, readErr := s.reader.ReadString('\n')
	if readErr == nil {
		return cmd, nil
	}
	return "", readErr
}
func handleConnection(clientConn net.Conn) {
smtp, err := newSMTPConnection(clientConn)
if err != nil {
log.Errorf("Failed to create SMTP connection: %v", err)
clientConn.Close()
return
}
defer smtp.close()
log.Infof("New connection from %s", smtp.clientAddr)
// Forward initial greeting
smtp.target.Write([]byte(""))
response := make([]byte, 1024)
n, err := smtp.target.Read(response)
if err != nil {
log.Errorf("Failed to get greeting: %v", err)
return
}
smtp.client.Write([]byte(response[:n]))
for {
if !smtp.isRcptValid {
line, err := smtp.readCommand()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
log.Warnf("Client timeout from %s: %v", smtp.clientAddr, err)
} else {
log.Errorf("Error reading from client %s: %v", smtp.clientAddr, err)
}
return
}
cmd := strings.ToUpper(strings.TrimSpace(line))
log.Debugf("Received command: %s", cmd)
// Handle STARTTLS before anything else
if strings.HasPrefix(cmd, "STARTTLS") {
if err := smtp.upgradeToTLS(); err != nil {
log.Errorf("Failed to send STARTTLS response: %v", err)
return
}
log.Infof("Connection upgraded to TLS for %s", smtp.clientAddr)
continue
}
if strings.HasPrefix(cmd, "AUTH") {
smtp.client.Write([]byte("554 5.7.1 Access denied\r\n"))
log.Warnf("Rejected recipient from %s: external auth denied", smtp.clientAddr)
return
}
// Handle RCPT TO validation inside TLS if needed
if strings.HasPrefix(cmd, "RCPT TO:") {
parts := strings.Split(strings.ToLower(line), ":")
if len(parts) != 2 {
smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n"))
return
}
// email := strings.Trim(strings.TrimSpace(parts[1]), "<>")
// domain := strings.Split(email, "@")
// if len(domain) != 2 {
// smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n"))
// return
// }
// fmt.Sscanf(parts[1], "<%s>%s", &email, &whatever)
// email := re.FindStringSubmatch(parts[1])
// if len(email) < 1 {
// log.Errorf("%s is not email address", email[1])
// return
// }
// domain := strings.Split(email[0], "@")
start := strings.Index(parts[1], "<")
end := strings.Index(parts[1], ">")
if start == -1 || end == -1 || start >= end {
smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n"))
return
}
email := parts[1][start+1 : end]
domain := strings.Split(email, "@")
if len(domain) != 2 {
smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n"))
return
}
if domain[1] != allowDomain {
log.Warnf("Rejected recipient from %s: domain %s not allowed", smtp.clientAddr, domain[1])
smtp.client.Write([]byte("554 Domain not allowed\r\n"))
return
}
smtp.isRcptValid = true
}
// Check DATA command
if strings.HasPrefix(cmd, "DATA") && !smtp.isRcptValid {
log.Warnf("Disconnecting %s - attempted DATA without valid recipient", smtp.clientAddr)
return
}
//Otherwise Forward to target and get response
smtp.target.Write([]byte(line))
n, err := smtp.target.Read(response)
if err != nil {
log.Errorf("Failed to read from target: %v", err) // Added log output
return
}
smtp.client.Write([]byte(response[:n]))
} else {
for smtp.reader.Buffered() > 0 {
line, _ := smtp.readCommand()
smtp.target.Write([]byte(line))
n, err := smtp.target.Read(response)
if err != nil {
log.Errorf("Failed to read from target: %v", err) // Added log output
return
}
smtp.client.Write([]byte(response[:n]))
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
if _, err := io.Copy(smtp.target, smtp.client); err != nil {
log.Errorf("Error forwarding client to server: %v", err)
}
}()
// 从服务器到客户端
go func() {
defer wg.Done()
if _, err := io.Copy(smtp.client, smtp.target); err != nil {
log.Errorf("Error forwarding server to client: %v", err)
}
}()
wg.Wait()
log.Infof("Connection from %s closed", smtp.client.conn.RemoteAddr())
return
}
}
}
// main loads the TLS key pair, listens on listenAddr and serves every
// accepted connection on its own goroutine via handleConnection.
//
// TLS session secrets are written to a key log only when the SSLKEYLOGFILE
// environment variable names a file. A key log lets anyone who can read it
// decrypt captured sessions, so it must stay a debugging aid and is no
// longer enabled unconditionally.
func main() {
	// Configure logging
	log.SetLevel(log.InfoLevel)

	// Load certificate
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		log.Fatalf("Failed to load certificate: %v", err)
	}

	// TLS config for server
	serverTLSConf = &tls.Config{
		Certificates: []tls.Certificate{cert},
	}

	// Optional key logging. KeyLogWriter is assigned only after a
	// successful open: storing a nil *os.File in the interface field would
	// make it non-nil and break every handshake on the first key-log write.
	if path := os.Getenv("SSLKEYLOGFILE"); path != "" {
		keyLogFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
		if err != nil {
			log.Errorf("Failed to open TLS key log %s: %v", path, err)
		} else {
			log.Warnf("TLS session keys are being logged to %s", path)
			serverTLSConf.KeyLogWriter = keyLogFile
		}
	}

	// Create TCP listener
	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Fatalf("Failed to start server: %v", err)
	}
	defer listener.Close()
	log.Infof("SMTP filter listening on %s", listenAddr)

	// Accept connections
	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Errorf("Failed to accept connection: %v", err)
			// Pause briefly so a persistent failure (for example running
			// out of file descriptors) does not spin the CPU.
			time.Sleep(50 * time.Millisecond)
			continue
		}
		go handleConnection(conn)
	}
}