#!/usr/bin/env python
# vim: ts=3 sw=3 ai
#
#  Log information about incoming SMTP connections.
#
#  Copyright (c) 2004-2007, Sean Reifschneider, tummy.com, ltd.
#  All Rights Reserved
#  <jafo@tummy.com>
#
# For code copied from pyspf, the following applies:
# Copyright (c) 2003, Terence Way
# Portions Copyright (c) 2004,2005,2006,2007,2008 Stuart Gathman <stuart@bmsi.com>
# Portions Copyright (c) 2005,2006,2007,2008,2011,2012 Scott Kitterman <scott@kitterman.com>
# This module is free software, and you may redistribute it and/or modify
# it under the same terms as Python itself, so long as this copyright message
# and disclaimer are retained in their original form.
#
# IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
# SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF
# THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE.  THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
# AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
# SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.


S_rcsid = '$Id: tumgreyspf,v 1.29 2007-10-08 00:39:37 jafo Exp $'


import syslog, os, sys, string, re, time, popen2, urllib, stat, errno, socket
import spf
sys.path.append('/usr/lib/tumgreyspf')
import tumgreyspfsupp

# Log to the mail facility, tagging each message with this script's
# basename and the process id.
syslog.openlog(os.path.basename(sys.argv[0]), syslog.LOG_PID, syslog.LOG_MAIL)
# NOTE(review): presumably installs an exception hook so uncaught errors are
# reported somewhere useful -- confirm in tumgreyspfsupp.setExceptHook.
tumgreyspfsupp.setExceptHook()

#############################################
# Copied from pyspf 2.0.7
import struct

def addr2bin(str):
    """Return the unsigned 32-bit integer value of a textual IPv4 address.

    The text is parsed by socket.inet_aton(), so besides the usual
    dotted-quad form the shorter n, n.n and n.n.n notations are accepted
    too (unlike DNS.addr2bin).  socket.error is raised for text that
    inet_aton() rejects.

    Examples::
    >>> addr2bin('127.0.0.1') == 2130706433
    True

    >>> addr2bin('127.0.0.1') == socket.INADDR_LOOPBACK
    True

    >>> addr2bin('255.255.255.254') == 4294967294
    True

    >>> addr2bin('192.168.0.1') == 3232235521
    True

    Short forms::
    >>> addr2bin('10.65536') == 167837696
    True

    >>> addr2bin('10.93.512') == 173867520
    True
    """
    packed = socket.inet_aton(str)
    # "!L" = one network-order (big-endian) unsigned 32-bit value.
    (value,) = struct.unpack("!L", packed)
    return value

def bin2long6(str):
    """Return the 128-bit integer value of a 16-byte packed IPv6 address."""
    # The address is read as two network-order unsigned 64-bit halves.
    upper_half, lower_half = struct.unpack("!QQ", str)
    return (upper_half << 64) | lower_half

if hasattr(socket,'has_ipv6') and socket.has_ipv6:
    # Native IPv6 support: thin wrappers that pin the address family so the
    # rest of the file can call inet_ntop()/inet_pton() with one argument.
    def inet_ntop(s):
        """Convert a 16-byte packed IPv6 address to its text form."""
        return socket.inet_ntop(socket.AF_INET6,s)
    def inet_pton(s):
        """Convert a textual IPv6 address to its 16-byte packed form."""
        return socket.inet_pton(socket.AF_INET6,s)
else:
    # Pure-Python fallbacks (copied from pyspf) for interpreters built
    # without IPv6 support.

    # Dotted-quad IPv4 address anchored at the end of the string.  The
    # fallback inet_pton() below needs it; it was referenced but never
    # defined, which made every call fail with NameError.
    RE_IP4 = re.compile(r'\.'.join(
        [r'(?:\d|[1-9]\d|1\d\d|2[0-4]\d|25[0-5])']*4) + '$')

    def inet_ntop(s):
        """Convert ip6 address to standard hex notation.

        s is the 16-byte packed address; the return value is its textual
        form with the longest run of zero words collapsed to '::'.

        Examples:
        >>> inet_ntop(struct.pack("!HHHHHHHH",0,0,0,0,0,0xFFFF,0x0102,0x0304))
        '::FFFF:1.2.3.4'
        >>> inet_ntop(struct.pack("!HHHHHHHH",0x1234,0x5678,0,0,0,0,0x0102,0x0304))
        '1234:5678::102:304'
        >>> inet_ntop(struct.pack("!HHHHHHHH",0,0,0,0x1234,0x5678,0,0x0102,0x0304))
        '::1234:5678:0:102:304'
        >>> inet_ntop(struct.pack("!HHHHHHHH",0x1234,0x5678,0,0x0102,0x0304,0,0,0))
        '1234:5678:0:102:304::'
        >>> inet_ntop(struct.pack("!HHHHHHHH",0,0,0,0,0,0,0,0))
        '::'
        """
        # convert to 8 words
        a = struct.unpack("!HHHHHHHH",s)
        n = (0,0,0,0,0,0,0,0)     # null ip6
        if a == n: return '::'
        # check for ip4 mapped
        if a[:5] == (0,0,0,0,0) and a[5] in (0,0xFFFF):
            ip4 = '.'.join([str(i) for i in struct.unpack("!HHHHHHBBBB",s)[6:]])
            if a[5]:
                return "::FFFF:" + ip4
            return "::" + ip4
        # find index of longest sequence of 0
        for l in (7,6,5,4,3,2,1):
            e = n[:l]
            for i in range(9-l):
                if a[i:i+l] == e:
                    if i == 0:
                        return ':'+':%x'*(8-l) % a[l:]
                    if i == 8 - l:
                        return '%x:'*(8-l) % a[:-l] + ':'
                    return '%x:'*i % a[:i] + ':%x'*(8-l-i) % a[i+l:]
        return "%x:%x:%x:%x:%x:%x:%x:%x" % a

    def inet_pton(p):
        """Convert ip6 standard hex notation to ip6 address.

        p is the textual address; the return value is the 16-byte packed
        form.  A bare dotted-quad IPv4 address is returned as the
        corresponding IPv4-mapped address.  Raises ValueError (carrying
        the original text) when p cannot be parsed.

        Examples:
        >>> struct.unpack('!HHHHHHHH',inet_pton('::'))
        (0, 0, 0, 0, 0, 0, 0, 0)
        >>> struct.unpack('!HHHHHHHH',inet_pton('::1234'))
        (0, 0, 0, 0, 0, 0, 0, 4660)
        >>> struct.unpack('!HHHHHHHH',inet_pton('1234::'))
        (4660, 0, 0, 0, 0, 0, 0, 0)
        >>> struct.unpack('!HHHHHHHH',inet_pton('1234::5678'))
        (4660, 0, 0, 0, 0, 0, 0, 22136)
        >>> struct.unpack('!HHHHHHHH',inet_pton('::FFFF:1.2.3.4'))
        (0, 0, 0, 0, 0, 65535, 258, 772)
        >>> struct.unpack('!HHHHHHHH',inet_pton('1.2.3.4'))
        (0, 0, 0, 0, 0, 65535, 258, 772)
        >>> try: inet_pton('::1.2.3.4.5')
        ... except ValueError as x: print(x)
        ::1.2.3.4.5
        """
        if p == '::':
            return '\0'*16
        s = p
        m = RE_IP4.search(s)
        try:
            if m:
                # Trailing dotted quad: either the whole address (IPv4
                # mapped) or the last 32 bits, rewritten as two hex words.
                pos = m.start()
                ip4 = [int(i) for i in s[pos:].split('.')]
                if not pos:
                    return struct.pack('!QLBBBB',0,65535,*ip4)
                s = s[:pos]+'%x%02x:%x%02x'%tuple(ip4)
            a = s.split('::')
            if len(a) == 2:
                # One '::' present: pad the gap with zero words.
                l,r = a
                if not l:
                    r = r.split(':')
                    return struct.pack('!HHHHHHHH',
                        *[0]*(8-len(r)) + [int(s,16) for s in r])
                if not r:
                    l = l.split(':')
                    return struct.pack('!HHHHHHHH',
                        *[int(s,16) for s in l] + [0]*(8-len(l)))
                l = l.split(':')
                r = r.split(':')
                return struct.pack('!HHHHHHHH',
                    *[int(s,16) for s in l] + [0]*(8-len(l)-len(r))
                    + [int(s,16) for s in r])
            if len(a) == 1:
                return struct.pack('!HHHHHHHH',
                    *[int(s,16) for s in a[0].split(':')])
        except ValueError: pass
        raise ValueError(p)

#############################################
def cidrmatch(connectip, ipaddrs, n):
    """Match connect IP against a list of other IP addresses. From pyspf.

    connectip -- textual address of the connecting client; it is treated
                 as IPv6 when it contains a ':' and as IPv4 otherwise.
    ipaddrs   -- list of textual addresses of the same family.  The list
                 is only read, never modified.
    n         -- prefix length: number of leading bits that must agree.

    Returns True when the first n bits of connectip equal the first n bits
    of at least one entry of ipaddrs, False otherwise.  Any address that
    cannot be parsed makes the whole match fail (False) instead of raising.
    """
    try:
        if connectip.count(':'):
            mask = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
            connect_bits = bin2long6(inet_pton(connectip))
            # Convert into a new list; the caller's list stays textual.
            candidates = [bin2long6(inet_pton(addr)) for addr in ipaddrs]
        else:
            mask = 0xFFFFFFFF
            connect_bits = addr2bin(connectip)
            candidates = [addr2bin(addr) for addr in ipaddrs]
        # Keep only the n most significant bits of each address.
        network_mask = ~(mask >> n) & mask
        wanted = network_mask & connect_bits
        for candidate in candidates:
            if (network_mask & candidate) == wanted:
                return True
    except (socket.error, ValueError):
        # Deliberate best effort: socket.error comes from the native
        # parsers, ValueError from the pure-Python inet_pton fallback.
        pass
    return False

def parse_cidr(cidr_ip):
    """Break CIDR notation into an (address, prefix_length) tuple.

    Derived from pyspf.  Accepted forms:

      address            prefix defaults to 32 (IPv4) or 128 (IPv6)
      address/N          N is the prefix length
      address//N         SPF "dual CIDR" form; N applies to IPv6 addresses

    An address is treated as IPv6 when it contains a ':'.  For IPv6 the
    '//N' value is preferred, then '/N', then the default of 128.  For
    IPv4 the '/N' value is used, with a default of 32.

    >>> parse_cidr('127.0.0.0/8')
    ('127.0.0.0', 8)
    >>> parse_cidr('::1//128')
    ('::1', 128)
    >>> parse_cidr('10.1.2.3')
    ('10.1.2.3', 32)
    """
    import re
    RE_DUAL_CIDR = re.compile(r'//(0|[1-9]\d*)$')
    RE_CIDR = re.compile(r'/(0|[1-9]\d*)$')

    # split() with a capturing group yields [address, digits, ''] on a match.
    cidr6 = None
    parts = RE_DUAL_CIDR.split(cidr_ip)
    if len(parts) == 3:
        cidr_ip, cidr6 = parts[0], int(parts[1])

    cidr = None
    parts = RE_CIDR.split(cidr_ip)
    if len(parts) == 3:
        cidr_ip, cidr = parts[0], int(parts[1])

    if ':' in cidr_ip:
        if cidr6 is None:
            cidr6 = cidr
        if cidr6 is None:
            cidr6 = 128
        return cidr_ip, cidr6

    if cidr is None:
        cidr = 32
    return cidr_ip, cidr

#############################################
def spfcheck(data, configData, configGlobal):  #{{{1
	"""Run an SPF check for one policy request.

	data         -- dict of request attributes; the keys read here are
	                'client_address', 'sender', 'helo_name', 'recipient'
	                and 'queue_id'.
	configData   -- per-request configuration (not used by this checker).
	configGlobal -- global configuration; 'debugLevel',
	                'spfAcceptOnPermError' and 'spfqueryPath' are read.

	Returns an (action, text) tuple.  action is one of:
	  None      -- no decision (text is None or an informational note)
	  'prepend' -- text is a header line to add to the message
	  'reject'  -- text is the rejection reason
	  'defer'   -- text is the temporary-failure reason
	"""
	debugLevel = configGlobal.get('debugLevel', 0)
	queue_id = 'QUEUE_ID="%s"' % data.get('queue_id', '<UNKNOWN>')
	ip = data.get('client_address')
	if ip is None:
		if debugLevel: syslog.syslog('spfcheck: No client address, exiting')
		return ( None, None )

	# Do not check SPF for localhost addresses - add to skip addresses to
	# skip SPF for internal networks if desired.
	skip_addresses = ['127.0.0.0/8', '::ffff:127.0.0.0//104', '::1//128',]
	for cidr in skip_addresses:
		address, prefix = parse_cidr(cidr)
		if cidrmatch(ip, [address], int(prefix)):
			return ( None, 'SPF check N/A for local connections' )

	sender = data.get('sender')
	helo = data.get('helo_name')
	if not sender and not helo:
		if debugLevel: syslog.syslog('spfcheck: No sender or helo, exiting')
		return ( None, None )

	#  if no helo name sent, use domain from sender
	if not helo:
		senderParts = sender.split('@', 1)
		if len(senderParts) < 2: helo = 'unknown'
		else: helo = senderParts[1]

	if hasattr(spf, 'check2'):
		#  use the pySPF 2.0/RFC 4408 interface
		try:
			ret = spf.check2(i = ip, s = sender, h = helo)
		except Exception as e:
			#  a failing SPF library must not block mail: report it in a
			#  header and let the message continue
			header = 'TumGreySPF-Warning: SPF Check failed: %s' % str(e)
			return ( 'prepend', header )

		spfReason = repr(str(ret[1].strip()))
		#  normalise the result so it can be compared as 'Fail', 'None', ...
		spfResult = ret[0].strip().lower().capitalize()
		if spfResult == 'None':
			#  allow checking to continue on to other checkers if no SPF
			return ( None, None )
		if sender:
			identity = 'identity=mailfrom; '
		else:
			identity = 'identity=helo; '
		spfDetail = (identity + 'client-ip=%s; helo=%s; envelope-from=%s; '
				'receiver=%s; '
				% ( data.get('client_address', '<UNKNOWN>'),
					data.get('helo_name', '<UNKNOWN>'),
					data.get('sender', '<UNKNOWN>'),
					data.get('recipient', '<UNKNOWN>'),
					))
		syslog.syslog('%s: %s; %s' % ( spfReason, queue_id, spfDetail ))

		if spfResult == 'Fail':
			return ( 'reject', '%s SPF Reports: %s' % ( queue_id, spfReason ))

		spfAcceptOnPermError = configGlobal.get('spfAcceptOnPermError', 1)
		if spfResult == 'Permerror' and not spfAcceptOnPermError:
			return ( 'reject', '%s SPF Reports: %s' % ( queue_id, spfReason ))

		if spfResult == 'Temperror':
			return ( 'defer', '%s SPF Reports: %s' % ( queue_id, spfReason ))

		header = ('Received-SPF: '+ spfResult + ' (' + spfReason
				+ ') ' + spfDetail)

		return ( 'prepend', header )

	#  fall back to older pre-RFC interface
	try:
		ret = spf.check(i = ip, s = sender, h = helo)
	except Exception as e:
		header = 'TumGreySPF-Warning: SPF Check failed: %s' % str(e)
		return ( 'prepend', header )

	spfResult = ret[0].strip()
	spfReason = repr(str(ret[2].strip()))

	#  try spfquery
	if not spfResult:
		#  check for spfquery
		spfqueryPath = configGlobal['spfqueryPath']
		if not os.path.exists(spfqueryPath):
			if debugLevel:
				syslog.syslog('spfcheck: No spfquery at "%s", exiting'
						% spfqueryPath)
			return ( None, None )

		#  open connection to spfquery
		fpIn, fpOut = popen2.popen2('%s -file -' % spfqueryPath)
		fpOut.write('%s %s %s\n' % ( ip, sender, helo ))
		fpOut.close()
		spfData = fpIn.readlines()
		fpIn.close()
		if debugLevel:
			syslog.syslog('spfcheck: spfquery result: "%s"' % str(spfData))
		#  a result line and a reason line are needed; without them there
		#  is no verdict, so leave the decision to the other checkers
		if len(spfData) < 2:
			syslog.syslog('spfcheck: Unexpected spfquery output, exiting')
			return ( None, None )
		spfResult = spfData[0].strip()
		spfReason = repr(str(spfData[1].strip()))

	#  read result
	if spfResult == 'fail' or spfResult == 'deny':
		syslog.syslog('SPF fail: REMOTEIP="%s" HELO="%s" SENDER="%s" '
				'RECIPIENT="%s" %s REASON="%s"'
				% ( data.get('client_address', '<UNKNOWN>'),
					data.get('helo_name', '<UNKNOWN>'),
					data.get('sender', '<UNKNOWN>'),
					data.get('recipient', '<UNKNOWN>'),
					queue_id, spfReason ) )

		return ( 'reject', '%s SPF Reports: %s'
				% ( queue_id, spfReason ) )

	if debugLevel:
		syslog.syslog('spfcheck: pyspf result: "%s"' % str(ret))

	return ( None, None )


##################################################
def _logGreylistEvent(state, data):
	"""Log one greylisting decision ('Initial', 'Pending', ...) for data."""
	syslog.syslog('%s greylisting: REMOTEIP="%s" HELO="%s" SENDER="%s" '
			'RECIPIENT="%s" QUEUEID="%s"'
			% ( state,
				data.get('client_address', '<UNKNOWN>'),
				data.get('helo_name', '<UNKNOWN>'),
				data.get('sender', '<UNKNOWN>'),
				data.get('recipient', '<UNKNOWN>'),
				data.get('queue_id', '<UNKNOWN>'), ) )


def greylistcheck(data, configData, configGlobal):  #{{{1
	"""Apply greylisting to one policy request.

	State is kept on disk: one empty file per greylist entry under
	configGlobal['greylistDir'].  The file's mtime is the moment the entry
	becomes allowed and its atime is refreshed on every request.

	data         -- dict of request attributes ('client_address', 'sender',
	                'recipient', 'helo_name', 'queue_id').
	configData   -- per-request configuration; 'GREYLISTTIME' (seconds,
	                default 600) is read.
	configGlobal -- global configuration; 'greylistDir', 'ignoreLastByte',
	                'greylistByIPOnly' and 'defaultSeedOnly' are read.

	Returns (None, None) when the request is allowed or cannot be checked,
	or ('defer', reason) while the entry is still greylisted or the data
	directory could not be created.
	"""
	greylistDir = configGlobal['greylistDir']

	ip = data.get('client_address')
	if ip is None:
		return ( None, None )
	ipBytes = ip.split('.')
	#  Only drop the last byte of a dotted address; an address without
	#  dots (IPv6) would otherwise collapse to an empty path shared by
	#  every such client.
	if configGlobal['ignoreLastByte'] > 0 and len(ipBytes) > 1:
		ipBytes = ipBytes[:-1]
	ipPath = '/'.join(ipBytes)

	if configGlobal['greylistByIPOnly'] > 0:
		entryDir = os.path.join(greylistDir, ipPath)
		path = os.path.join(entryDir, 'check_file')
	else:
		sender = data.get('sender')
		recipient = data.get('recipient')
		if not sender or not recipient:
			return ( None, None )
		sender = tumgreyspfsupp.quoteAddress(sender)
		recipient = tumgreyspfsupp.quoteAddress(recipient)
		entryDir = os.path.join(greylistDir, 'client_address', ipPath,
				'greylist', sender)
		path = os.path.join(entryDir, recipient)

	allowTime = configData.get('GREYLISTTIME', 600)
	seedOnly = configGlobal.get('defaultSeedOnly')

	if not os.path.exists(path):
		if not os.path.exists(entryDir):
			#  if multiple messages come in at once
			#  it can cause multiple makedirs
			for attempt in range(10):
				try:
					os.makedirs(entryDir)
					break
				except OSError as msg:
					if msg.errno != errno.EEXIST: raise
					#  someone else finished creating it: nothing to wait for
					if os.path.isdir(entryDir): break
					time.sleep(1)

			#  still didn't succeed
			if not os.path.exists(entryDir):
				syslog.syslog(('ERROR: Could not create directory after '
						'10 seconds: "%s"') % entryDir)
				return ( 'defer', 'Service unavailable, error creating data '
						'directory.  See /var/log/maillog for more information.' )

		#  create file; mtime records when the entry becomes allowed
		open(path, 'w').close()
		now = time.time()
		os.utime(path, ( now, now + allowTime ))

		if seedOnly:
			_logGreylistEvent('Training', data)
			return ( None, None )

		_logGreylistEvent('Initial', data)
		return ( 'defer', 'Service unavailable, greylisted '
				'(http://projects.puremagic.com/greylisting/).' )

	#  is it time to allow yet
	mtime = os.stat(path)[stat.ST_MTIME]
	now = time.time()
	#  refresh atime (last seen) but keep the allow time
	os.utime(path, ( now, mtime ))
	if mtime > now and not seedOnly:
		_logGreylistEvent('Pending', data)
		return ( 'defer', 'Service unavailable, greylisted.' )

	_logGreylistEvent('Allowed', data)
	return ( None, None )


###################################################
def blackholecheck(data, configData, configGlobal):  #{{{1
	"""Reject mail from blackholed client addresses.

	A client IP is blackholed when a directory for it exists under
	<blackholeDir>/ips.  Mail to a recipient that has an entry under
	<blackholeDir>/addresses blackholes the sending IP from then on.

	data         -- dict of request attributes ('client_address',
	                'recipient', 'helo_name', 'sender', 'queue_id').
	configData   -- per-request configuration (not used by this checker).
	configGlobal -- global configuration; 'blackholeDir' is read.

	Returns ('reject', reason) for a blackholed client, otherwise
	(None, None).
	"""
	blackholeDir = configGlobal['blackholeDir']

	ip = data.get('client_address')
	if ip is None:
		return ( None, None )
	ipPath = '/'.join(ip.split('.'))
	ipDir = os.path.join(blackholeDir, 'ips', ipPath)

	recipient = data.get('recipient')
	if not recipient:
		return ( None, None )
	recipient = tumgreyspfsupp.quoteAddress(recipient)

	#  add blackhole
	recipientPath = os.path.join(blackholeDir, 'addresses', recipient)
	if os.path.exists(recipientPath) and not os.path.exists(ipDir):
		try:
			#  was os.path.makedirs, which does not exist
			os.makedirs(ipDir)
		except OSError as e:
			#  another process may have created it in the meantime
			if e.errno != errno.EEXIST: raise

	#  check for existing blackhole entry
	if os.path.exists(ipDir):
		syslog.syslog('Blackholed: REMOTEIP="%s" HELO="%s" SENDER="%s" '
				'RECIPIENT="%s" QUEUEID="%s"'
				% ( data.get('client_address', '<UNKNOWN>'),
					data.get('helo_name', '<UNKNOWN>'),
					data.get('sender', '<UNKNOWN>'),
					data.get('recipient', '<UNKNOWN>'),
					data.get('queue_id', '<UNKNOWN>'), ) )

		return ( 'reject', 'Service unavailable, blackholed.' )

	return ( None, None )


###################
#  load config file  {{{1
#  The optional first command-line argument names the configuration file.
configFile = tumgreyspfsupp.defaultConfigFilename
if len(sys.argv) > 1:
	if sys.argv[1] in ( '-?', '--help', '-h' ):
		sys.stdout.write('usage: tumgreyspf [<configfilename>]\n')
		sys.exit(1)
	configFile = sys.argv[1]
configGlobal = tumgreyspfsupp.processConfigFile(filename = configFile)

#  loop reading data  {{{1
#  Requests arrive on stdin as "name=value" lines; an empty line ends one
#  request.  Each request is answered on stdout with a single "action=..."
#  line followed by an empty line.
debugLevel = configGlobal.get('debugLevel', 0)
if debugLevel >= 2: syslog.syslog('Starting')
instance_list = []     #  instances that already got a prepended header
data = {}              #  attributes of the request being read
lineRx = re.compile(r'^\s*([^=\s]+)\s*=(.*)$')
while True:
	line = sys.stdin.readline()
	if not line: break
	line = line.rstrip()
	if debugLevel >= 4: syslog.syslog('Read line: "%s"' % line)

	#  end of entry  {{{2
	if not line:
		if debugLevel >= 4: syslog.syslog('Found the end of entry')
		configData = tumgreyspfsupp.lookupConfig(configGlobal.get('configPath'),
				data, configGlobal)
		if debugLevel >= 2: syslog.syslog('Config: %s' % str(configData))

		#  run the checkers  {{{3
		#  Checkers run in the configured order.  The first reject/defer
		#  ends the run.  A 'prepend' from the SPF checker is remembered
		#  while later checkers run, and is only replaced when one of them
		#  returns an action of its own; a later "no opinion" result must
		#  not discard the header.
		checkerValue = None
		checkerReason = None
		for checkerType in configData.get('CHECKERS', '').split(','):
			checkerType = checkerType.strip()

			if checkerType == 'greylist':
				value, reason = greylistcheck(data, configData,
						configGlobal)
			elif checkerType == 'spf':
				value, reason = spfcheck(data, configData,
						configGlobal)
				if configData.get('SPFSEEDONLY', 0):
					value = None
					reason = None
			elif checkerType == 'blackhole':
				value, reason = blackholecheck(data, configData,
						configGlobal)
			else:
				continue

			if value is None: continue
			checkerValue = value
			checkerReason = reason
			if value != 'prepend': break

		#  handle results  {{{3

		if checkerValue == 'reject':
			sys.stdout.write('action=550 %s\n\n' % checkerReason)

		elif checkerValue == 'prepend':
			instance = data.get('instance')
			# The following if is only needed for testing.  Postfix
			# will always provide instance.
			if not instance:
				import random
				instance = str(int(random.random()*100000))
			# This is to prevent multiple headers being prepended
			# for multi-recipient mail.
			if instance not in instance_list:
				sys.stdout.write('action=prepend %s\n\n' % checkerReason)
				instance_list.append(instance)
			else:
				sys.stdout.write('action=dunno\n\n')
		elif checkerValue == 'defer':
			sys.stdout.write('action=defer_if_permit %s\n\n' % checkerReason)
		else:
			sys.stdout.write('action=dunno\n\n')

		#  end of record  {{{3
		sys.stdout.flush()
		data = {}
		continue

	#  parse line  {{{2
	m = lineRx.match(line)
	if not m:
		syslog.syslog('ERROR: Could not match line "%s"' % line)
		continue

	#  save the string  {{{2
	#  values are lower-cased, except for the listed keys
	key = m.group(1)
	value = m.group(2)
	if key not in [ 'protocol_state', 'protocol_name', 'queue_id' ]:
		value = value.lower()
	data[key] = value
