269 lines
6.9 KiB
Go
269 lines
6.9 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto/tls"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Network settings of the filter.
const (
	// listenAddr is the address main listens on for incoming clients.
	listenAddr = ":2525"
	// targetAddr is the upstream SMTP server dialed for every client.
	targetAddr = "localhost:25"
	// allowDomain is the only recipient domain accepted in RCPT TO.
	allowDomain = "libertarian.dev"
	// timeout is not referenced by the code visible in this file; the
	// per-operation deadline in timeoutConn is spelled out separately.
	timeout time.Duration = 30 * time.Second
)

// Key pair presented to clients that issue STARTTLS; loaded once in main.
const (
	certFile = "/home/haswell/Projects/mailu-deploy/mailu/certs/cert.pem"
	keyFile  = "/home/haswell/Projects/mailu-deploy/mailu/certs/key.pem"
)
|
|
|
|
// serverTLSConf is the TLS configuration used for client connections that are
// upgraded via STARTTLS. It is nil until main assigns it.
var serverTLSConf *tls.Config
|
|
|
|
// re matches a whole string that looks like a simple e-mail address
// (local part, '@', dotted domain, alphabetic TLD of at least two letters).
// In the code visible here it is only referenced from commented-out code.
var re = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
|
|
|
// timeoutConn wraps a net.Conn so that every Read and Write first pushes the
// connection deadline 30 seconds into the future. It implements io.Reader and
// io.Writer only; the wrapped connection stays reachable through conn.
type timeoutConn struct {
	conn net.Conn
}

// extendDeadline re-arms the combined read/write deadline of the wrapped
// connection. Both directions are extended on purpose: a TLS connection may
// need to write while serving a Read (and vice versa) during a handshake.
func (c timeoutConn) extendDeadline() error {
	return c.conn.SetDeadline(time.Now().Add(30 * time.Second))
}

// Read extends the deadline and then reads from the wrapped connection.
// If the deadline cannot be set, no read is attempted and that error is
// returned, so the caller never blocks on a connection without a deadline.
func (c timeoutConn) Read(buf []byte) (int, error) {
	if err := c.extendDeadline(); err != nil {
		return 0, err
	}
	return c.conn.Read(buf)
}

// Write extends the deadline and then writes to the wrapped connection.
// If the deadline cannot be set, nothing is written and that error is
// returned.
func (c timeoutConn) Write(buf []byte) (int, error) {
	if err := c.extendDeadline(); err != nil {
		return 0, err
	}
	return c.conn.Write(buf)
}
|
|
|
|
type smtpConnection struct {
|
|
client timeoutConn
|
|
target timeoutConn
|
|
reader *bufio.Reader
|
|
clientAddr string
|
|
isRcptValid bool
|
|
}
|
|
|
|
func newSMTPConnection(clientConn net.Conn) (*smtpConnection, error) {
|
|
targetConn, err := net.Dial("tcp", targetAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &smtpConnection{
|
|
client: timeoutConn{clientConn},
|
|
target: timeoutConn{targetConn},
|
|
reader: bufio.NewReader(clientConn),
|
|
clientAddr: clientConn.RemoteAddr().String(),
|
|
isRcptValid: false,
|
|
}, nil
|
|
}
|
|
|
|
func (s *smtpConnection) close() {
|
|
s.client.conn.Close()
|
|
s.target.conn.Close()
|
|
}
|
|
|
|
func (s *smtpConnection) upgradeToTLS() error {
|
|
// Upgrade client connection
|
|
if _, err := s.client.Write([]byte("220 Ready to start TLS\r\n")); err != nil {
|
|
return err
|
|
}
|
|
s.client.conn = tls.Server(s.client.conn, serverTLSConf)
|
|
|
|
// Reset reader for new TLS connection
|
|
s.reader = bufio.NewReader(s.client)
|
|
return nil
|
|
}
|
|
|
|
func (s *smtpConnection) readCommand() (string, error) {
|
|
line, err := s.reader.ReadString('\n')
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return line, nil
|
|
}
|
|
|
|
func handleConnection(clientConn net.Conn) {
|
|
smtp, err := newSMTPConnection(clientConn)
|
|
if err != nil {
|
|
log.Errorf("Failed to create SMTP connection: %v", err)
|
|
clientConn.Close()
|
|
return
|
|
}
|
|
defer smtp.close()
|
|
|
|
log.Infof("New connection from %s", smtp.clientAddr)
|
|
|
|
// Forward initial greeting
|
|
smtp.target.Write([]byte(""))
|
|
response := make([]byte, 1024)
|
|
n, err := smtp.target.Read(response)
|
|
if err != nil {
|
|
log.Errorf("Failed to get greeting: %v", err)
|
|
return
|
|
}
|
|
smtp.client.Write([]byte(response[:n]))
|
|
for {
|
|
if !smtp.isRcptValid {
|
|
line, err := smtp.readCommand()
|
|
if err != nil {
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
log.Warnf("Client timeout from %s: %v", smtp.clientAddr, err)
|
|
} else {
|
|
log.Errorf("Error reading from client %s: %v", smtp.clientAddr, err)
|
|
}
|
|
return
|
|
}
|
|
|
|
cmd := strings.ToUpper(strings.TrimSpace(line))
|
|
log.Debugf("Received command: %s", cmd)
|
|
|
|
// Handle STARTTLS before anything else
|
|
if strings.HasPrefix(cmd, "STARTTLS") {
|
|
|
|
if err := smtp.upgradeToTLS(); err != nil {
|
|
log.Errorf("Failed to send STARTTLS response: %v", err)
|
|
return
|
|
}
|
|
log.Infof("Connection upgraded to TLS for %s", smtp.clientAddr)
|
|
continue
|
|
}
|
|
if strings.HasPrefix(cmd, "AUTH") {
|
|
smtp.client.Write([]byte("554 5.7.1 Access denied\r\n"))
|
|
log.Warnf("Rejected recipient from %s: external auth denied", smtp.clientAddr)
|
|
return
|
|
}
|
|
// Handle RCPT TO validation inside TLS if needed
|
|
if strings.HasPrefix(cmd, "RCPT TO:") {
|
|
parts := strings.Split(strings.ToLower(line), ":")
|
|
if len(parts) != 2 {
|
|
smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n"))
|
|
return
|
|
}
|
|
|
|
// email := strings.Trim(strings.TrimSpace(parts[1]), "<>")
|
|
// domain := strings.Split(email, "@")
|
|
// if len(domain) != 2 {
|
|
// smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n"))
|
|
// return
|
|
// }
|
|
// fmt.Sscanf(parts[1], "<%s>%s", &email, &whatever)
|
|
// email := re.FindStringSubmatch(parts[1])
|
|
// if len(email) < 1 {
|
|
// log.Errorf("%s is not email address", email[1])
|
|
// return
|
|
// }
|
|
// domain := strings.Split(email[0], "@")
|
|
start := strings.Index(parts[1], "<")
|
|
end := strings.Index(parts[1], ">")
|
|
if start == -1 || end == -1 || start >= end {
|
|
smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n"))
|
|
return
|
|
}
|
|
|
|
email := parts[1][start+1 : end]
|
|
domain := strings.Split(email, "@")
|
|
if len(domain) != 2 {
|
|
smtp.client.Write([]byte("501 Syntax error in parameters or arguments\r\n"))
|
|
return
|
|
}
|
|
if domain[1] != allowDomain {
|
|
log.Warnf("Rejected recipient from %s: domain %s not allowed", smtp.clientAddr, domain[1])
|
|
smtp.client.Write([]byte("554 Domain not allowed\r\n"))
|
|
return
|
|
}
|
|
smtp.isRcptValid = true
|
|
}
|
|
|
|
// Check DATA command
|
|
if strings.HasPrefix(cmd, "DATA") && !smtp.isRcptValid {
|
|
log.Warnf("Disconnecting %s - attempted DATA without valid recipient", smtp.clientAddr)
|
|
return
|
|
}
|
|
//Otherwise Forward to target and get response
|
|
smtp.target.Write([]byte(line))
|
|
n, err := smtp.target.Read(response)
|
|
if err != nil {
|
|
log.Errorf("Failed to read from target: %v", err) // Added log output
|
|
return
|
|
}
|
|
smtp.client.Write([]byte(response[:n]))
|
|
} else {
|
|
for smtp.reader.Buffered() > 0 {
|
|
line, _ := smtp.readCommand()
|
|
smtp.target.Write([]byte(line))
|
|
n, err := smtp.target.Read(response)
|
|
if err != nil {
|
|
log.Errorf("Failed to read from target: %v", err) // Added log output
|
|
return
|
|
}
|
|
smtp.client.Write([]byte(response[:n]))
|
|
}
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
if _, err := io.Copy(smtp.target, smtp.client); err != nil {
|
|
log.Errorf("Error forwarding client to server: %v", err)
|
|
}
|
|
}()
|
|
|
|
// 从服务器到客户端
|
|
go func() {
|
|
defer wg.Done()
|
|
if _, err := io.Copy(smtp.client, smtp.target); err != nil {
|
|
log.Errorf("Error forwarding server to client: %v", err)
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
log.Infof("Connection from %s closed", smtp.client.conn.RemoteAddr())
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
// Configure logging
|
|
log.SetLevel(log.InfoLevel)
|
|
keyLogFile, _ := os.OpenFile("/tmp/tlskey.log", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
|
// Load certificate
|
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
if err != nil {
|
|
log.Fatalf("Failed to load certificate: %v", err)
|
|
return
|
|
}
|
|
|
|
// TLS config for server
|
|
serverTLSConf = &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
KeyLogWriter: keyLogFile,
|
|
}
|
|
|
|
// Create TCP listener
|
|
listener, err := net.Listen("tcp", listenAddr)
|
|
if err != nil {
|
|
log.Fatalf("Failed to start server: %v", err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
log.Infof("SMTP filter listening on %s", listenAddr)
|
|
|
|
// Accept connections
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
log.Errorf("Failed to accept connection: %v", err)
|
|
continue
|
|
}
|
|
go handleConnection(conn)
|
|
}
|
|
}
|