#!/usr/bin/python

# AttackReport.py - Analyze the log generated by AttackLog.py
# Copyright (c) 2003 - John W. Peterson (linux AT saccade DOT com)
# Permission is granted to copy and use, so long as this copywrite notice
# remains intact, and any improvements are forwarded back to the original author.

import datetime
import re
import string
import sys
import time

class ItemCounter(dict):
	"""A dictionary that maps each key to the number of times it was counted."""
	def countItem(self, key):
		"""Increment the count of KEY, starting from zero if it is not present yet."""
		# dict.get() replaces the has_key() test, which no longer exists in Python 3.
		self[key] = self.get(key, 0) + 1

class ItemList(dict):
	"""A dictionary that groups items into lists, one list per key."""
	def __init__( self, itemList, tupleIndex ):
		"""Group the tuples of ITEMLIST by their TUPLEINDEX'th element.

		Each distinct value found at that index becomes a key; its value is the
		list of tuples carrying it, in the order they appear in ITEMLIST."""
		dict.__init__(self)
		for item in itemList:
			self.addItem( item[tupleIndex], item )

	def addItem(self, key, trans):
		"""Append TRANS to the list for KEY, creating the list on first use."""
		# setdefault() replaces the has_key() test, which no longer exists in Python 3.
		self.setdefault(key, []).append(trans)

class AttackList(list):
	"""Subclass of list for managing 'attacks'.

	Each attack is a tuple with four items: the time (time.time format), the
	port name, the type of message (TCP or UDP) and the host name it
	originated from.  The class attributes below name the tuple indices."""
	# Tuple indices
	TIME = 0
	PORT = 1
	MSGTYPE = 2
	HOST = 3

	# One log line: "... (<time>) <port> (<TCP|UDP>) <host>".  Raw string so the
	# backslashes reach the regex engine untouched.
	_LOG_REGEXP = re.compile(r'.*\(([0-9.]+)\) ([^ ]+) \(([TU][CD]P)\) (.*)$')

	def __init__( self, filename = None ):
		"""Create an AttackList by parsing the file generated by AttackLog.py.

		If FILENAME is None, an empty list is created.  If FILENAME is a list,
		its items are copied in as-is (they are expected to be attack tuples).
		Otherwise FILENAME is opened and parsed line by line; lines starting
		with '--' and lines that do not match the log format are ignored."""
		list.__init__(self)
		if filename is None:
			return

		if isinstance( filename, list ):
			self.extend(filename)
			return

		# 'with' guarantees the file is closed even if parsing raises.
		with open(filename, 'r') as f:
			for line in f:
				if line.startswith('--'):	## Skip timestamps
					continue
				attackInfo = self._LOG_REGEXP.search(line)
				if not attackInfo:
					continue
				stamp, portName, msgType, hostname = attackInfo.groups()
				try:
					attackTime = float(stamp)
				except ValueError:
					# The pattern accepts things like "1.2.3"; treat such a
					# line like any other line that does not parse.
					continue
				self.append((attackTime, portName, msgType, hostname))

	def _formatAttack(self, attack):
		"""Return the one-line printable form of a single attack tuple."""
		when = time.strftime("%d-%b-%y %H:%M:%S", time.localtime(attack[self.TIME]))
		return "%s: %-18s (%3s) %s" % (when, attack[self.PORT], attack[self.MSGTYPE], attack[self.HOST])

	def FindTriples(self, window = 15):
		"""Find runs of three consecutive attacks from the same host within WINDOW seconds.

		Returns a new AttackList holding the last attack of every such run.
		Only the host is compared; the port is not."""
		triples = AttackList()
		# Start at index 2 so that every window of three entries is examined,
		# including the one that ends on the final entry of the list.
		for i in range(2, len(self)):
			first, middle, last = self[i-2], self[i-1], self[i]
			sameHost = first[self.HOST] == middle[self.HOST] == last[self.HOST]
			sameTime = (last[self.TIME] - first[self.TIME]) < window
			if sameHost and sameTime:
				triples.append(last)
		return triples

	def DNSSortedHosts(self):
		"""Print the distinct domains attacks came from, sorted from the top-level label down.

		Each host name is cut to its three most significant DNS labels.  Hosts
		whose name starts with 'UNKNOWN (' are left out."""
		def reverseName(name, keep = 0):
			labels = name.split('.')
			labels.reverse()
			if keep > 0:
				labels = labels[:keep]
			return '.'.join(labels)

		unknown = 'UNKNOWN ('
		hosts = set(attack[self.HOST] for attack in self)
		# Sorting the reversed names orders by TLD first, then by each subdomain.
		reversedDomains = sorted(set(reverseName(h, 3) for h in hosts if not h.startswith(unknown)))
		for domain in reversedDomains:
			print("%40s" % reverseName(domain))

	def DomainHistogram(self, level = 1, minCount = 4 ):
		"""Print a list of domains attacks originate from, most frequent first.

		LEVEL is how many DNS labels to keep (0 keeps only the top-level
		domain).  A domain is printed only if it appears more than MINCOUNT
		times.  Domains with equal counts are listed alphabetically."""
		def domainOf(name):
			unknown = 'UNKNOWN'
			if name.startswith(unknown):
				return unknown
			labels = name.split('.')
			if len(labels) == 1:
				return labels[0]
			labels.reverse()
			keep = level
			if (level > 0) and (labels[1] in ['edu', 'com', 'net', 'org']):	# Some countries have US style TLDs
				keep += 1
			elif (level == 0):
				keep = 1
			labels = labels[:keep]
			labels.reverse()
			return '.'.join(labels)

		domainCounts = ItemCounter()
		for attack in self:
			domainCounts.countItem( domainOf( attack[self.HOST] ) )
		# Key-based sort (works on Python 2 and 3): descending count, then name.
		ranked = sorted(domainCounts, key = lambda d: (-domainCounts[d], d))
		total = len(self)
		for domain in ranked:
			count = domainCounts[domain]
			if count > minCount:
				print("%30s: %4d (%5.2f%%)" % (domain, count, count * 100.0 / total))

	def DayHistogram(self):
		"""Print a histogram of attacks by (local) calendar day.

		Every day between the earliest and the latest attack gets a line, even
		if nothing was logged on it.  Prints nothing for an empty list."""
		if not self:
			return
		# Calendar-date arithmetic copes with leap years, logs spanning more
		# than a year and daylight-saving changes.
		days = [datetime.date.fromtimestamp(attack[self.TIME]) for attack in self]
		firstDay = min(days)
		lastDay = max(days)
		dayHistogram = [0] * ((lastDay - firstDay).days + 1)
		for day in days:
			dayHistogram[(day - firstDay).days] += 1

		total = len(self)
		for offset, count in enumerate(dayHistogram):
			day = firstDay + datetime.timedelta(days = offset)
			print("%s : %3d (%4.1f%%)" % (day.strftime("%a, %d-%b-%y"), count, count * 100.0 / total))

	def HourHistogram(self):
		"""Print a histogram of attacks by (local) hour in the day."""
		hourHistogram = [0] * 24
		for attack in self:
			hourHistogram[time.localtime(attack[self.TIME])[3]] += 1
		total = len(self)
		for h, count in enumerate(hourHistogram):
			# An empty list reports 0% everywhere instead of dividing by zero.
			if total:
				percent = count * 100.0 / total
			else:
				percent = 0.0
			print("%02d:00 - %02d:59 = %4d (%5.2f%%)" % (h, h, count, percent))

	def AttackFreqency(self):
		"""Print overall attack frequency (NOTE: assumes contiguous time span in logfile).

		The time span is taken from the first and last entries.  If the list
		is empty or spans no time, a short notice is printed instead."""
		if not self:
			print("Attack frequency unavailable (no attacks logged)")
			return
		totalTime = float(self[-1][self.TIME] - self[0][self.TIME])
		if totalTime <= 0:
			print("Attack frequency unavailable (log spans no time)")
			return
		secsPerAttack = totalTime / len(self)
		print("Attacked %3.1f times a day" % (len(self) / (totalTime / (24*60*60.0)),))
		print("Attacked every %d:%02d minutes" % (int(secsPerAttack / 60), int(secsPerAttack) % 60))

	def GetMatching(self, item):
		"""Return a new AttackList of the attacks having ITEM as one of their fields (e.g., a host or port)."""
		# Build a real list: the constructor only copies when handed a list,
		# and filter() returns an iterator on Python 3.
		return AttackList([attack for attack in self if item in attack])

	def PrintMatching(self, item):
		"""Print every attack having ITEM as one of its fields."""
		for attack in self.GetMatching(item):
			print(self._formatAttack(attack))

	def PrintAll(self):
		"""Print every attack in the list."""
		for attack in self:
			print(self._formatAttack(attack))

	def FrequentAttackers(self, minimumAttacks = 3):
		"""Print hosts that attack more than minimumAttacks times, most active first."""
		hostNames = ItemList( self, self.HOST )
		freqs = [(len(attacks), host) for host, attacks in hostNames.items() if len(attacks) > minimumAttacks]
		# Descending count; ties are broken by host name.
		freqs.sort(key = lambda f: (-f[0], f[1]))
		for count, host in freqs:
			print("%3d : %s" % (count, host))

	def PortReport(self):
		"""Print a list of ports attacked, most attacked first."""
		ports = ItemList( self, self.PORT )
		total = len(self)
		# Descending count; ties are broken by port name.
		for port in sorted(ports, key = lambda p: (-len(ports[p]), p)):
			hits = len(ports[port])
			print("%-18s = %5d (%5.2f%%)" % (port, hits, hits * 100.0 / total))

# al = AttackList('AttackLog.txt')

# Command-line entry point: print the full report for the log file named on
# the command line.  Guarded so that importing this module (e.g. to use
# AttackList interactively) does not print anything.
if __name__ == '__main__':
	if len(sys.argv) > 1:
		al = AttackList(sys.argv[1])
		if not al:
			# Nothing parsed: every percentage below would divide by zero.
			print("No attacks found in %s" % sys.argv[1])
		else:
			print("--Summary:")
			al.AttackFreqency()
			print("Percent of Triple Attacks: %4.1f%%" % (len(al.FindTriples()) * 100.0 / len(al),))
			print("--Frequent Attackers:")
			al.FrequentAttackers()
			print("--Attacks by time of day:")
			al.HourHistogram()
			print("--Attacks by day:")
			al.DayHistogram()
			print("--Ports attacked")
			al.PortReport()
			print("--Top level Domains:")
			al.DomainHistogram(0,4)
	else:
		print("Usage: AttackReport.py logfile")
