236 lines
5.3 KiB
Go
236 lines
5.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rc4"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/binary"
|
|
"sync"
|
|
)
|
|
|
|
// EncryptorConfig bundles the tunables that control an Encryptor.
type EncryptorConfig struct {
	// RSAKeySize is the modulus length, in bits, of the RSA key pair that
	// NewEncryptor generates.
	RSAKeySize int

	// KeyExchangeOp is the opcode byte (at offset 1) that marks an outbound
	// packet as a key-exchange request.
	KeyExchangeOp byte

	// MinSize is the smallest outbound packet length, in bytes, that is
	// encrypted; shorter packets are forwarded unchanged.
	MinSize int
}
|
|
|
|
// DefaultEncryptorConfig returns default encryptor configuration
|
|
func DefaultEncryptorConfig() *EncryptorConfig {
|
|
return &EncryptorConfig{
|
|
RSAKeySize: 1024,
|
|
KeyExchangeOp: 0x21, // OP_WSLoginRequestMsg equivalent
|
|
MinSize: 8,
|
|
}
|
|
}
|
|
|
|
// Encryptor implements RC4 + RSA encryption middleware
|
|
type Encryptor struct {
|
|
config *EncryptorConfig
|
|
rsaKey *rsa.PrivateKey
|
|
rc4Key []byte
|
|
cipher *rc4.Cipher
|
|
cipherMux sync.RWMutex
|
|
keySet bool
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
// NewEncryptor creates a new encryption middleware
|
|
func NewEncryptor(config *EncryptorConfig) *Encryptor {
|
|
if config == nil {
|
|
config = DefaultEncryptorConfig()
|
|
}
|
|
|
|
// Generate RSA key pair
|
|
rsaKey, err := rsa.GenerateKey(rand.Reader, config.RSAKeySize)
|
|
if err != nil {
|
|
panic(err) // Should handle this better in production
|
|
}
|
|
|
|
return &Encryptor{
|
|
config: config,
|
|
rsaKey: rsaKey,
|
|
}
|
|
}
|
|
|
|
// ProcessOutbound implements Middleware.ProcessOutbound
|
|
func (e *Encryptor) ProcessOutbound(data []byte, next func([]byte) (int, error)) (int, error) {
|
|
// Check if this is a key exchange request
|
|
if len(data) > 4 && data[0] == 0 && data[1] == e.config.KeyExchangeOp {
|
|
return e.handleKeyExchange(data, next)
|
|
}
|
|
|
|
// Skip encryption for small packets or if no key is set
|
|
if len(data) < e.config.MinSize || !e.isKeySet() {
|
|
return next(data)
|
|
}
|
|
|
|
encrypted, err := e.encrypt(data)
|
|
if err != nil {
|
|
return next(data) // Fallback to unencrypted
|
|
}
|
|
|
|
return next(encrypted)
|
|
}
|
|
|
|
// ProcessInbound implements Middleware.ProcessInbound
|
|
func (e *Encryptor) ProcessInbound(data []byte, next func([]byte) (int, error)) (int, error) {
|
|
// Check for RSA encrypted key at end of packet
|
|
if len(data) >= 8 && e.isRSAEncryptedKey(data) {
|
|
return e.processRSAKey(data, next)
|
|
}
|
|
|
|
// Skip decryption if no key is set
|
|
if !e.isKeySet() {
|
|
return next(data)
|
|
}
|
|
|
|
decrypted, err := e.decrypt(data)
|
|
if err != nil {
|
|
return next(data) // Fallback to unencrypted
|
|
}
|
|
|
|
return next(decrypted)
|
|
}
|
|
|
|
func (e *Encryptor) handleKeyExchange(data []byte, next func([]byte) (int, error)) (int, error) {
|
|
// Extract key size from packet
|
|
if len(data) < 8 {
|
|
return next(data)
|
|
}
|
|
|
|
keySize := binary.LittleEndian.Uint32(data[4:8])
|
|
if keySize != 60 { // Expected key size
|
|
return next(data)
|
|
}
|
|
|
|
// Create key exchange response with RSA public key
|
|
response := make([]byte, len(data))
|
|
copy(response, data)
|
|
|
|
// Fill with dummy key data (in real implementation, would use proper key)
|
|
for i := 8; i < len(response)-8; i++ {
|
|
response[i] = 0xFF
|
|
}
|
|
|
|
// Add termination markers
|
|
response[len(response)-5] = 1
|
|
response[len(response)-1] = 1
|
|
|
|
return next(response)
|
|
}
|
|
|
|
func (e *Encryptor) processRSAKey(data []byte, next func([]byte) (int, error)) (int, error) {
|
|
// Extract and decrypt RSA key from end of packet
|
|
encryptedKey := data[len(data)-8:]
|
|
|
|
// In real implementation, would decrypt with RSA private key
|
|
// For now, use a simple XOR pattern
|
|
rc4Key := make([]byte, 8)
|
|
for i := 0; i < 8; i++ {
|
|
rc4Key[i] = encryptedKey[i] ^ 0x55 // Simple pattern
|
|
}
|
|
|
|
e.setRC4Key(rc4Key)
|
|
|
|
// Pass the packet without the RSA key
|
|
return next(data[:len(data)-8])
|
|
}
|
|
|
|
func (e *Encryptor) isRSAEncryptedKey(data []byte) bool {
|
|
// Simple heuristic - check if last 8 bytes look like encrypted data
|
|
if len(data) < 8 {
|
|
return false
|
|
}
|
|
|
|
// Check for non-zero data in last 8 bytes
|
|
lastBytes := data[len(data)-8:]
|
|
nonZero := 0
|
|
for _, b := range lastBytes {
|
|
if b != 0 {
|
|
nonZero++
|
|
}
|
|
}
|
|
return nonZero > 4 // Heuristic: encrypted data should have some non-zero bytes
|
|
}
|
|
|
|
func (e *Encryptor) setRC4Key(key []byte) {
|
|
e.cipherMux.Lock()
|
|
defer e.cipherMux.Unlock()
|
|
|
|
e.rc4Key = make([]byte, len(key))
|
|
copy(e.rc4Key, key)
|
|
|
|
cipher, err := rc4.NewCipher(key)
|
|
if err == nil {
|
|
e.cipher = cipher
|
|
e.keySet = true
|
|
}
|
|
}
|
|
|
|
func (e *Encryptor) isKeySet() bool {
|
|
e.cipherMux.RLock()
|
|
defer e.cipherMux.RUnlock()
|
|
return e.keySet
|
|
}
|
|
|
|
func (e *Encryptor) encrypt(data []byte) ([]byte, error) {
|
|
e.cipherMux.Lock()
|
|
defer e.cipherMux.Unlock()
|
|
|
|
if e.cipher == nil {
|
|
return data, nil
|
|
}
|
|
|
|
// Create new cipher for this operation (RC4 is stateful)
|
|
cipher, err := rc4.NewCipher(e.rc4Key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
encrypted := make([]byte, len(data))
|
|
cipher.XORKeyStream(encrypted, data)
|
|
return encrypted, nil
|
|
}
|
|
|
|
func (e *Encryptor) decrypt(data []byte) ([]byte, error) {
|
|
e.cipherMux.Lock()
|
|
defer e.cipherMux.Unlock()
|
|
|
|
if e.cipher == nil {
|
|
return data, nil
|
|
}
|
|
|
|
// Create new cipher for this operation (RC4 is stateful)
|
|
cipher, err := rc4.NewCipher(e.rc4Key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
decrypted := make([]byte, len(data))
|
|
cipher.XORKeyStream(decrypted, data)
|
|
return decrypted, nil
|
|
}
|
|
|
|
// GetPublicKey returns the RSA public key for key exchange
|
|
func (e *Encryptor) GetPublicKey() []byte {
|
|
pubKeyBytes, err := x509.MarshalPKIXPublicKey(&e.rsaKey.PublicKey)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return pubKeyBytes
|
|
}
|
|
|
|
// Close implements Middleware.Close
|
|
func (e *Encryptor) Close() error {
|
|
e.closeOnce.Do(func() {
|
|
e.cipherMux.Lock()
|
|
e.cipher = nil
|
|
e.rc4Key = nil
|
|
e.keySet = false
|
|
e.cipherMux.Unlock()
|
|
})
|
|
return nil
|
|
}
|