#!/usr/bin/env python
# NAME:		qclib.py
# PURPOSE:	providing classes and functions used for QC scripts
# AUTHOR:	Burkhard Wolff, ESO-DMO
# VERSIONS:	1.0	-- July 2008
#		1.0.1	-- arr_stddev, LinePlot: fix for sigma==0 (2008-08-07)
#		1.0.2	-- ImagePlot: addcol, etc. added to draw (2008-08-08)
#		1.0.3	-- arr_stddev: median can be passed as parameter;
#			   ImagePlot: faster plotting of sub images (2008-10-09)
#		1.0.4	-- HistoPlot: allow predefined binrange (2008-11-10)
#		1.0.5	-- HistoPlot: returns xmin, xmax, ymin, ymax of plot (2009-01-16)
#		1.0.6	-- ImagePlot: add_symbols added (2009-02-20)
#		1.1	-- new classes and functions: basic_options, wrapper_options, set_logging, get_from_NGAS;
#			   new/enhanced methods (adapted from JP) for AssociationBlock: get_rawHDUs, get_hdr, get_pro (2009-03-20)
#		1.1.1	-- new: find_fits, find_mcal; logging output to stdout (2009-04-03/2009-04-21)
#		1.1.2	-- minor change within ImagePlot (2009-05-15)
#		1.1.3	-- new: paper_size, get_fontsizes; image compression for ImagePlot (2009-05-27)
#		1.1.4	-- bug fix in ImagePlot (2009-07-24)
#		1.1.5	-- bug fix in find_fits; new font sizes for a5 and a6 (2009-08-06)
#		1.2	-- upgrade to numpy V1.3 and matplotlib V0.99 (2009-10-27)
#		1.3	-- BasicPlot added (2010-02-09)
#		1.3.1	-- new: get_avgkeyval, get_avgfromhdrlist (2010-04-22)
#		1.3.2	-- improved symbol plotting in ImagePlot (2010-09-20)
#		1.3.3	-- bug fix with xlim/ylim in ImagePlot [only visible in case image is small] (2011-03-10)
#		1.3.4	-- new: overexp, overexp_fromlist (2013-04-29)
#		1.3.5	-- ImagePlot: option title_length for method draw (2013-05-23)
#		1.3.6	-- AssociationBlock: new method get_mcal (2013-05-28)
#		1.3.7	-- ImagePlot: reduce number of x locators if there are more than 4 sub plots in x (2013-08-08)
#		1.3.8	-- AssociationBlock: new method get_proList (2013-12-05)
#		1.3.9	-- new: get_keyval (2013-12-06)
#		1.3.10	-- ImagePlot: use min/max values as cuts if cuts==();
#			   AssociationBlock: get_pro delivers only first product in case of multiple products 
#			   with same PRO.CATG (2014-01-30)
#		1.3.11	-- new: get_arr (2014-02-07)
#		1.3.12	-- usage of argparse module, Sect. 0.3 added (2014-03-07)
#		1.3.13	-- remove usage of has_key dictionary attribute (2014-03-11)
#		1.3.14	-- AssociationBlock: AB content accessible via attributes;
#		           new: get_wavearray (2014-03-27)
#		1.4	-- new: BarPlot (2014-04-16)
#		1.4.1	-- HistoPlot: introduce binnum, shift bins by half binsize for plotting (2015-03-31)
#		1.4.2	-- HistoPlot: avoid stepsize < 1 if binnum is too large (2015-06-09)
#		1.5	-- new: arr_MAD; improvements for binning in HistoPlot (2016-02-18)
#		1.5.1	-- BasicPlot: use drawstyle='steps-mid', was linestyle='steps' (2016-04-25)
#		1.5.2	-- ImagePlot: bug fixes for addrow/addcol (2017-03-20)
#		1.5.3	-- ImagePlot: allow non-default x and y labels (2017-10-11)
#		1.5.4	-- force legend to upper right (2018-05-18)
#		1.5.5	-- usage of astropy.io instead of pyfits (2018-06-12)
#
# CONTENTS:
#		1. input/output
#			def print_message:	[obsolete]
#			class get_options:	[obsolete]
#			class basic_options:	[obsolete]
#			class wrapper_options:	[obsolete]
#			def set_logging:	set logging level, uses standard module logging
#			def get_from_NGAS:	download file from NGAS
#			def find_fits:		find fits file that has specified header key values
#			def find_mcal:		find a calibration product with specified PRO.CATG
#			class ConfigFile:	representation of ASCII configuration file (DFOS style)
#			class AssociationBlock:	representation of AB contents
#			def get_confname:	get name of config file for importing script-specific configuration
#		2. mathematics/statistics
#			def histogram		returns histogram of a flattened (1D) image array [WH]
#			def residuals		returns residuals in Gaussian fit [WH]
#			def pval		returns Gauss function values [WH]
#			def arr_median		returns median of 2D image
#			def arr_stddev		returns standard deviation of 2D image
#			def arr_MAD		returns Median Absolute Deviation [WH]
#			def im_divide		image division taking care of division by zero
#		3. convenience functions for accessing fits header keys and data in fits files
#			def get_keyval		returns value of header key
#			def get_avgkeyval	returns average of two header key values
#			def get_avgfromhdrlist	returns average and rms of a keyword value across a list of headers
#			def get_wavearray	returns an array with wavelengths calculated from CRVAL, CRPIX, and CDELT
#			def get_arr		returns data array from HDU
#			def overexp		returns number of overexposed pixels in one frame
#			def overexp_fromlist	returns maximum number of overexposed pixels from a list of frames
#		4. QC plotting
#			def paper_size		returns paper size in inch for A4, A5, A6
#			def get_fontsizes	returns font sizes used for A4, A5, A6
#			class BasicPlot:	plots 1D data
#			class LinePlot:		plots cuts through 2D images
#			class HistoPlot:	plots histograms
#			class ImagePlot:	plots 2D images for all extensions
#			class SpecPlot:		plots 1D spectra
#			class BarPlot:		plots data with error bars
#			def plot_image:		plots a 2D image with optional color bar and source identification

# version string of this module; the VERSIONS list in the file header records the changes per release
_version_ = "1.5.5"

# =====================================================================================
# 0. initialization
# 0.1 import modules
# =====================================================================================

import sys					# interaction with Python interpreter
import os					# operating system services
import subprocess				# running external programs (e.g. ngasClient)
import argparse					# for parsing command line arguments and options
from optparse import OptionParser		# for parsing command line arguments (deprecated)
import logging					# output of logging messages
import math					# mathematical functions

import numpy					# methods for numerical arrays
from astropy.io import fits as pyfits		# fits file handling
import scipy					# numerical functions
import scipy.optimize				# least square optimization
import pylab					# plotting module

# =====================================================================================
# 0.2 set plotting parameters
# =====================================================================================

# change some default plotting parameters
# NOTE: module-level side effect -- importing this module modifies matplotlib's
# global rcParams for all plots made afterwards in the same process
from matplotlib import rcParams
rcParams['font.size'] = 8
rcParams['lines.linewidth'] = 0.75
rcParams['axes.linewidth'] = 0.75
rcParams['legend.borderpad'] = 0.3
rcParams['legend.borderaxespad'] = 0.3
rcParams['legend.handletextpad'] = 0.4
rcParams['legend.labelspacing'] = 0.0
rcParams['legend.loc'] = 'upper right'
rcParams['patch.linewidth'] = 0.75

# colour definitions (matplotlib colour specifications: grey-level strings,
# single-letter colour names, or RGB tuples)
col_raw1 = '0.00'		# 1st raw frame
col_raw2 = '0.40'		# last raw frame
col_mst = 'r'			# cuts through master frames
col_ref = 'g'			# reference frame
col_histo = 'g'			# histogram fit
col_row = (0.00, 0.00, 1.00)	# row average
col_col = (0.00, 0.50, 1.00)	# column average

# default colour map (default value of the cmap argument of ImagePlot.add_image)
col_map = pylab.cm.hot

# =====================================================================================
# 0.3 initialize command line parsers
# =====================================================================================

# Shared argparse parsers. Both are built with add_help=False so that a
# script can list them in parents=[...] of its own ArgumentParser.

# options common to all QC scripts: -a AB (mandatory)
basic_parser = argparse.ArgumentParser(add_help=False)
basic_parser.add_argument(
	'-a',
	dest='ab',
	metavar='AB',
	default=None,
	required=True,
	help='name of AB',
)

# options of wrapper scripts: everything in basic_parser plus -i/--ingest
wrapper_parser = argparse.ArgumentParser(add_help=False, parents=[basic_parser])
wrapper_parser.add_argument(
	'-i', '--ingest',
	dest='ingest',
	action='store_true',
	default=False,
	help='ingest QC1 parameters',
)

# =====================================================================================
# 1. input/output
# =====================================================================================

def print_message(level, message):
	"""print an info, warning, or error message to stdout (obsolete, use logging)

	The output line is '<program name> [<TAG>] <message>', where TAG is
	INF for level 'I'/'INF', WAR for 'W'/'WAR', ERR for 'E'/'ERR', and
	empty for any other level. The line is formatted into a single string
	so that the function works with both the Python 2 print statement and
	the Python 3 print function.
	"""

	this_prog = os.path.basename(sys.argv[0])
	if level in ('I', 'INF'):
		tag = '[INF]'
	elif level in ('W', 'WAR'):
		tag = '[WAR]'
	elif level in ('E', 'ERR'):
		tag = '[ERR]'
	else:
		tag = '[]'
	print('%s %s %s' % (this_prog, tag, message))

class get_options:
	"""parse the legacy QC command line options with optparse (obsolete)

	Recognised options: -a AB (required), -e EXT (int, default 0), -i
	(ingest), -q (calc_qc), -l/--loglevel LEVEL (default 'info'). The
	parsed values are stored in the attributes ab, ingest, ext, calc_qc,
	and logging. Without -a the help text is printed and the program exits
	with status 2.
	"""

	def __init__(self, vers="0.0"):
		opt_parser = OptionParser(version="%prog " + vers)
		add = opt_parser.add_option
		add("-a", dest="ab", metavar="AB", default="None",
			help="name of AB [required]")
		add("-e", dest="ext", metavar="EXT", type=int, default=0,
			help="fits extension [optional]")
		add("-i", dest="ingest", action="store_true", default=False,
			help="ingest QC1 parameters [optional]")
		add("-q", dest="calc_qc", action="store_true", default=False,
			help="calculate additional QC1 parameters [optional]")
		add("-l", "--loglevel", dest="logging", metavar="LEVEL", default="info",
			help="logging level [optional], LEVEL='debug', 'info', 'warning', or 'error' [default='info']")

		parsed = opt_parser.parse_args()[0]
		if parsed.ab == "None":
			opt_parser.print_help()
			sys.exit(2)

		# copy the parsed values onto this object
		for attr_name in ('ab', 'ingest', 'ext', 'calc_qc', 'logging'):
			setattr(self, attr_name, getattr(parsed, attr_name))

class basic_options:
	"""legacy optparse definition of the options shared by all QC scripts (obsolete)

	Options: -a/--AB AB (required), -e/--ext EXT (int, default 0),
	-l/--loglevel LEVEL (default 'info'). Subclasses add further options
	to self.parser before get() is called.
	"""

	def __init__(self, vers="0.0"):
		parser = OptionParser(version="%prog " + vers)
		parser.add_option("-a", "--AB", dest="ab", metavar="AB", default="None",
			help="name of AB [required]")
		parser.add_option("-e", "--ext", dest="ext", metavar="EXT", type=int, default=0,
			help="fits extension [optional]")
		parser.add_option("-l", "--loglevel", dest="logging", metavar="LEVEL", default="info",
			help="logging level [optional], LEVEL='debug', 'info', 'warning', or 'error' [default='info']")
		self.parser = parser

	def get(self):
		"""parse sys.argv and return the options; print help and exit(2) if -a/--AB is missing"""
		options = self.parser.parse_args()[0]
		if options.ab != "None":
			return options
		self.parser.print_help()
		sys.exit(2)

class wrapper_options(basic_options):
	"""legacy optparse options of wrapper scripts (obsolete)

	All options of basic_options plus -i/--ingest (store_true, default False).
	"""

	def __init__(self, vers="0.0"):
		basic_options.__init__(self, vers)
		self.parser.add_option(
			"-i", "--ingest",
			dest="ingest",
			action="store_true",
			default=False,
			help="ingest QC1 parameters [optional]",
		)

def set_logging(level='info', format=''):
	"""configure logging with logging.basicConfig, writing to stdout

	level:  'debug', 'info', 'warning' or 'error' (case-insensitive); any
	        other value selects INFO
	format: logging format string; '' selects
	        '%(module)15s [%(levelname)7s] %(message)s'
	(logging.basicConfig has no effect if the root logger is already configured.)
	"""

	log_format = '%(module)15s [%(levelname)7s] %(message)s' if format == '' else format
	level_by_name = dict(
		debug=logging.DEBUG,
		info=logging.INFO,
		warning=logging.WARNING,
		error=logging.ERROR,
	)
	logging.basicConfig(
		level=level_by_name.get(level.lower(), logging.INFO),
		format=log_format,
		stream=sys.stdout,
	)

def get_from_NGAS(file, dir='', type='raw'):
	"""download file from the NGAS archive into the local directory dir

	file: name of the file to request
	dir:  target directory (default: current working directory)
	type: 'raw' (ngasClient -f), 'header' (-H) or 'product' (-c)

	Nothing is downloaded if dir/file already exists. This is best-effort:
	problems are logged, never raised, and the function returns None.
	"""

	options = {'raw': '-f', 'header': '-H', 'product': '-c'}
	if type not in options:
		logging.error('get_from_NGAS: file type not in ("raw", "header", "product"); cannot download file '+file)
		return
	if dir == '':
		dir = os.getcwd()
	if os.path.exists(dir+'/'+file):
		logging.info('get_from_NGAS: '+dir+'/'+file+' already existing; not downloaded from NGAS')
		return
	# Run ngasClient inside dir via cwd= instead of os.chdir, so the process-wide
	# working directory is never left changed. Pass the file name as its own
	# argument (no shell), so names with spaces or shell metacharacters are safe.
	try:
		status = subprocess.call(['ngasClient', options[type], file], cwd=dir)
	except OSError as err:
		logging.error('get_from_NGAS: could not run ngasClient for ' + file + ': ' + str(err))
		return
	if status != 0:
		logging.warning('get_from_NGAS: ngasClient returned status ' + str(status) + ' for ' + file)

def find_fits(header_keys=None, key_values=None, dir=''):
	"""find a fits file in dir whose primary header has the given key values

	header_keys: sequence of header keywords to check (default: none)
	key_values:  sequence of required values, same length as header_keys
	dir:         directory to search (default: current working directory)

	All files whose name ends with 'fits' are scanned in reverse
	alphabetical order. The name (without dir) of the first file whose
	primary header contains every keyword with the required value is
	returned, or '' if no file matches. With no keywords, the first
	scanned file is returned. Files that cannot be opened as FITS are
	skipped with a warning.
	Raises ValueError if header_keys and key_values differ in length.
	"""

	header_keys = [] if header_keys is None else list(header_keys)
	key_values = [] if key_values is None else list(key_values)
	if len(header_keys) != len(key_values):
		raise ValueError('find_fits: header_keys and key_values must have the same length')

	if dir == '':
		dir = os.getcwd()
	fits_files = sorted((name for name in os.listdir(dir) if name.endswith('fits')), reverse=True)

	for name in fits_files:
		full_name = dir + '/' + name
		try:
			# read the header while the file is open; it stays usable after closing
			with pyfits.open(full_name) as hdu_list:
				header = hdu_list[0].header
		except (IOError, OSError) as err:
			logging.warning('find_fits: cannot read ' + full_name + ': ' + str(err))
			continue
		if all(key in header and header[key] == value
		       for key, value in zip(header_keys, key_values)):
			return name

	return ''

def find_mcal(pro_catg, dir=''):
	"""return the name of a fits file in dir whose PRO.CATG equals pro_catg ('' if none is found)"""

	catg_keyword = 'HIERARCH ESO PRO CATG'
	return find_fits(header_keys=[catg_keyword], key_values=[pro_catg], dir=dir)

class ConfigFile:
	"""representation of an ASCII configuration file (DFOS style)

	format: <TAG> <VALUE1> <VALUE2> ...
	Empty lines and lines whose first word starts with '#' are ignored.

	After construction, self.content maps every TAG to
	  - for tags not in mult_tags: a single string, i.e. the first value
	    with environment variables expanded ('' if the line has no value
	    or the value is exactly '#'); a later line with the same tag
	    overwrites an earlier one;
	  - for tags in mult_tags: a list with one sub-list per line, holding
	    that line's values with environment variables expanded; words
	    starting with '#' are skipped.
	"""

	def __init__(self, file_name='', mult_tags=None):
		if mult_tags is None:
			mult_tags = []
		# 'with' closes the file even if reading fails
		with open(file_name, 'r') as conf_file:
			all_the_file = conf_file.read().splitlines()

		self.content = {}
		for line in all_the_file:
			words = line.split()
			if len(words) == 0:
				continue
			if words[0][0] == '#':
				continue
			tag = words[0]
			if not tag in mult_tags:
				# only first value is taken
				if len(words) == 1:
					self.content[tag] = ''
				else:
					if words[1] == '#':
						self.content[tag] = ''
					else:
						self.content[tag] = os.path.expandvars(words[1])
			else:
				if not tag in self.content:
					self.content[tag] = []
				sublist = []
				for word in words[1:]:
					# NOTE(review): only the word starting with '#' is skipped; words
					# after it on the same line are still collected -- confirm whether
					# an inline comment should end the line instead
					if word[0] == '#':
						continue
					sublist.append(os.path.expandvars(word))
				self.content[tag].append(sublist)


class AssociationBlock(ConfigFile):
	"""representation of the contents of an association block (AB)

	The AB file is parsed as a ConfigFile. Tags that can occur on several
	lines (RAWFILE, MCALIB, PRODUCTS, ...) are stored as lists of word
	lists in self.content. Since v1.3.14 every tag is also available as a
	lower-case attribute, e.g. self.rawfile is self.content['RAWFILE'].
	"""

	def __init__(self, ab_name=''):
		"""read the AB ab_name, either as given or relative to $DFO_AB_DIR"""
		if os.path.exists(ab_name):
			full_name = ab_name
		else:
			dfo_ab_dir = os.environ.get('DFO_AB_DIR')
			if dfo_ab_dir is None:
				# without DFO_AB_DIR, try ab_name as given so that open() reports
				# the missing file instead of a TypeError from None + str
				logging.warning('DFO_AB_DIR not set; reading AB ' + ab_name + ' without directory')
				full_name = ab_name
			else:
				full_name = dfo_ab_dir + '/' + ab_name
		mult_tag_list = ['RAW_MATCH_KEY', 'RAWFILE', 'RASSOC', 'PRODUCTS', 'MCALIB',
				'MASSOC', 'PARAM', 'WAITFOR', 'FURTHER_PS', 'FURTHER_GIF',
				'FURTHER_PNG', 'AB_STATUS', 'RB_CONTENT', 'SOF_CONTENT']
		ConfigFile.__init__(self, file_name=full_name, mult_tags=mult_tag_list)
		# new with v1.3.14: make every AB tag accessible as lower-case attribute
		for key in self.content:
			self.__dict__[key.lower()] = self.content[key]

	def _dfo_date_dir(self, env_var):
		"""return $<env_var>/<DATE of this AB>, or None if env_var is not set"""
		base_dir = os.environ.get(env_var)
		if base_dir is None:
			return None
		return base_dir + '/' + self.content['DATE']

	def get_raw(self):
		"""returns the list of raw HDUs, one per RAWFILE entry that can be opened

		Files that cannot be opened are skipped with a warning.
		"""
		HDU_list = []
		for item in self.content['RAWFILE']:
			file_name = item[0]
			try:
				HDU_list.append(pyfits.open(file_name))
			except Exception:
				logging.warning(file_name+' not found')
		return HDU_list

	def get_rawHDUs(self):
		"""returns a dictionary with RAW HDUs:
		dictionary keys are ARCFILE values, items are the HDUs"""
		HDU_dict = {}
		for temp_HDU in self.get_raw():
			HDU_dict[temp_HDU[0].header['ARCFILE']] = temp_HDU
		return HDU_dict

	def get_hdr(self):
		"""returns a dictionary with RAW HDUs built from ASCII header files:
		dictionary keys are ARCFILE values, items are header-only PrimaryHDUs

		For each RAWFILE entry, $DFO_HDR_DIR/<DATE>/<basename without
		extension>.hdr is read. Missing or unreadable header files are
		logged and skipped.
		"""
		HDU_dict = {}
		raw_items = self.content['RAWFILE']
		if not raw_items:
			return HDU_dict
		hdr_dir = self._dfo_date_dir('DFO_HDR_DIR')
		if hdr_dir is None:
			logging.error('DFO_HDR_DIR not set; cannot read raw headers')
			return HDU_dict
		for item in raw_items:
			file_name = os.path.splitext(hdr_dir + '/' + os.path.basename(item[0]))[0] + '.hdr'
			if not os.path.exists(file_name):
				logging.warning(file_name + ' not found')
				continue
			try:
				temp_HDU = self._hdu_from_hdr_file(file_name)
				arcfile = temp_HDU.header['ARCFILE']
			except Exception:
				logging.error('error reading ' + file_name)
				continue
			HDU_dict[arcfile] = temp_HDU
		return HDU_dict

	def _hdu_from_hdr_file(self, file_name):
		"""build a PrimaryHDU whose header holds the cards of an ASCII header file (one card per line)"""
		temp_HDU = pyfits.PrimaryHDU()
		with open(file_name, 'r') as hdrf:
			for line in hdrf:
				line = line.strip()
				# skip blank lines and the END card
				if line == '' or line == 'END':
					continue
				card = pyfits.Card.fromstring(line)
				keyword = card.keyword
				if line[0:8] == 'HIERARCH':
					# Card.keyword does not include the HIERARCH prefix; re-add it
					keyword = 'HIERARCH ' + keyword
				self._set_card(temp_HDU.header, keyword, card)
		return temp_HDU

	def _set_card(self, header, keyword, card):
		"""store card in header under keyword; on failure retry without comment, then with a truncated value"""
		# Header.set(keyword, value, comment) replaces the old pyfits
		# Header.update(keyword, value, comment) call, which astropy no longer supports
		try:
			header.set(keyword, card.value, card.comment)
			return
		except Exception:
			logging.warning('Could not handle %s, trying without comment'%(card.keyword))
		try:
			header.set(keyword, card.value)
			return
		except Exception:
			logging.warning('Still failed without comment, trying with truncated value')
		try:
			value = card.value
			if isinstance(value, (str, type(u''))):
				# keyword + ' = ' + two quotes + value must fit into one 80-character card
				value = value[0:(80-3-len(keyword)-2)]
			header.set(keyword, value)
		except Exception:
			logging.warning('Still failed with truncation, giving up')

	def get_pro(self):
		"""returns a dictionary with product HDUs:
		dictionary keys are PRO.CATG values, items are the HDUs

		Products are searched, in this order, in
		  1. PROD_PATH: files containing PROD_ROOT_NAME and ending in 'fits';
		  2. $DFO_CAL_DIR/<DATE>, then 3. $DFO_SCI_DIR/<DATE>: files selected
		     via dfits/fitsort by PROD_ROOT_NAME, downloaded from NGAS if needed.
		A fallback location is used only if nothing was found before. If
		several files have the same PRO.CATG, the first one in alphabetical
		order is kept. Files that cannot be opened are skipped with a warning.
		"""
		HDU_dict = {}
		prod_path = self.content['PROD_PATH']
		if os.path.exists(prod_path):
			files = sorted(name for name in os.listdir(prod_path)
				if self.content['PROD_ROOT_NAME'] in name and name.endswith('fits'))
			for name in files:
				self._add_product(HDU_dict, prod_path + '/' + name)

		# fall back to the DFOS calibration and science directories
		for env_var in ('DFO_CAL_DIR', 'DFO_SCI_DIR'):
			if len(HDU_dict) == 0:
				self._get_pro_from_dfo_dir(HDU_dict, env_var)
		return HDU_dict

	def _add_product(self, HDU_dict, full_name):
		"""open full_name and store it under its PRO.CATG unless that category is already present"""
		try:
			temp_HDU = pyfits.open(full_name)
		except (IOError, OSError):
			logging.warning(full_name + ' not found')
			return
		try:
			pro_catg = temp_HDU[0].header['HIERARCH ESO PRO CATG']
		except KeyError:
			temp_HDU.close()
			raise
		if pro_catg in HDU_dict:
			temp_HDU.close()
		else:
			HDU_dict[pro_catg] = temp_HDU

	def _get_pro_from_dfo_dir(self, HDU_dict, env_var):
		"""add the products of this AB found in $<env_var>/<DATE> to HDU_dict (downloading them from NGAS if needed)"""
		date_dir = self._dfo_date_dir(env_var)
		if date_dir is None or not os.path.exists(date_dir):
			return
		hf_files = ''.join(' ' + date_dir + '/' + name for name in os.listdir(date_dir)
			if name.endswith('hdr') or name.endswith('fits'))
		# list PIPEFILE/ORIGFILE of all headers, keep lines matching PROD_ROOT_NAME and
		# print the third fitsort column if present, otherwise the first one
		# NOTE(review): PROD_ROOT_NAME and file names are inserted into a shell command
		command = ('dfits ' + hf_files + ' | fitsort -d PIPEFILE ORIGFILE | grep '
			+ self.content['PROD_ROOT_NAME'] + ' | awk \'{print $3,$1}\' | awk \'{print $1}\'')
		pipe = os.popen(command)
		try:
			lines = pipe.readlines()
		finally:
			pipe.close()
		tfiles = []
		for line in lines:
			tfiles.append(date_dir + '/' + os.path.basename(line.strip()))
			get_from_NGAS(os.path.basename(tfiles[-1]), date_dir, 'product')
		for file_name in sorted(set(tfiles)):
			self._add_product(HDU_dict, file_name)

	def get_proList(self):
		"""returns a list of product HDUs instead of a dictionary;
		this is useful if product files with the same PRO.CATG exist

		The list holds all files in PROD_PATH that contain PROD_ROOT_NAME and
		end in 'fits', in alphabetical order; it is empty if PROD_PATH does not exist.
		"""
		prod_path = self.content['PROD_PATH']
		if not os.path.exists(prod_path):
			return []
		names = sorted(name for name in os.listdir(prod_path)
			if self.content['PROD_ROOT_NAME'] in name and name.endswith('fits'))
		return [pyfits.open(prod_path + '/' + name) for name in names]

	def get_mcal(self):
		"""returns a dictionary with master calib HDUs:
		dictionary keys are PRO.CATG values, items are the HDUs

		Each MCALIB line is read as <first word> <file> <PRO.CATG>; lines whose
		first word is 'NONE' are skipped, and shorter lines are skipped with a warning.
		"""
		HDU_dict = {}
		for line in self.content.get('MCALIB', []):
			if len(line) > 0 and line[0] == 'NONE':
				continue
			if len(line) < 3:
				logging.warning('get_mcal: incomplete MCALIB entry ' + ' '.join(line) + ' ignored')
				continue
			HDU_dict[line[2]] = pyfits.open(line[1])
		return HDU_dict

# get name of configuration file
# dependent on RAW_TYPE and RAW_MATCH_KEYs in AB
def get_confname(config_files, AB):
	"""return the name of the configuration module that matches the AB setting

	config_files: sequence of entries [name, raw_type, match_key, ...]
	An entry is a candidate if raw_type equals the AB's RAW_TYPE or is
	'ANY'. Its match keys are then checked in order: a key missing from
	the AB's RAW_MATCH_KEYs rejects the entry, and 'ANY' accepts it at
	once. The name of the first accepted entry with a non-empty name is
	returned, or '' if there is none.
	"""

	ab_keys = [entry[0] for entry in AB.content.get('RAW_MATCH_KEY', [])]

	for config in config_files:
		if config[1] != AB.content['RAW_TYPE'] and config[1] != 'ANY':
			continue
		accepted = True
		for raw_match in config[2:]:
			if raw_match == 'ANY':
				break
			if raw_match not in ab_keys:
				accepted = False
				break
		if accepted and config[0] != '':
			return config[0]
	return ''

# =====================================================================================
# 2. mathematical and statistical functions
# =====================================================================================

# histogram function
def histogram(a, bins):
	"""count the values of the 1D array a in the bins starting at the ascending edges bins

	Element i of the result counts values v with bins[i] <= v < bins[i+1].
	The last element counts all values >= bins[-1]. Values below bins[0]
	are not counted.
	"""
	# for every edge, the number of values strictly below it (searchsorted side='left')
	below = numpy.searchsorted(numpy.sort(a), bins)
	return numpy.diff(numpy.append(below, len(a)))

# residuals in Gaussian fit
def residuals(p, y, x):
	"""residuals y - G(x) of the Gaussian G(x) = A * exp(-(x-mu)^2 / (2 sigma^2))

	p = (A, mu, sigma); used as the error function for scipy.optimize.leastsq.
	numpy.exp is used because the scipy.exp alias of NumPy is no longer
	available in current SciPy releases.
	"""
	A, mu, sigma = p
	return y - A * numpy.exp(-(x-mu)*(x-mu)/2/sigma/sigma)

# calculate Gauss, based on current parameter vector
def pval(x, p):
	"""Gaussian p[0] * exp(-(x-p[1])^2 / (2 p[2]^2)) evaluated at x

	Uses numpy.exp, because the scipy.exp alias of NumPy is no longer
	available in current SciPy releases.
	"""
	return p[0]*numpy.exp(-(x-p[1])*(x-p[1])/2/p[2]/p[2])

# get median of a numpy array
def arr_median(arr):
	"""median of all elements of a numpy array (any shape)

	For an odd number of elements the central value is returned in the
	array's dtype. For an even number the mean of the two central values
	is returned. Integer data are converted to float64 before adding the
	two values, so that e.g. uint16 values near 65535 cannot overflow.
	Raises IndexError for an empty array.
	"""
	sorted_vals = numpy.sort(arr.ravel())
	n = sorted_vals.shape[0]
	mid = n // 2
	if n % 2 != 0:
		return sorted_vals[mid]
	lower = sorted_vals[mid-1]
	upper = sorted_vals[mid]
	if numpy.issubdtype(sorted_vals.dtype, numpy.integer):
		# avoid wrap-around of e.g. uint16 + uint16
		lower, upper = numpy.float64(lower), numpy.float64(upper)
	return (upper + lower) / 2.

# calculates standard deviation of a numpy array with clipping
def arr_stddev(arr, thres=5.0, median=None):
	"""standard deviation of arr after clipping at median +/- thres * sigma

	sigma is the plain standard deviation of all elements. The median is
	computed with arr_median unless it is passed in. If sigma is 0, it is
	returned directly. Otherwise the standard deviation of the elements
	strictly inside (median - thres*sigma, median + thres*sigma) is returned.
	"""
	center = arr_median(arr) if median is None else median
	full_sigma = arr.std()
	if full_sigma == 0.0:
		return full_sigma
	values = arr.ravel()
	upper = center + thres * full_sigma
	lower = center - thres * full_sigma
	inside = (values < upper) & (values > lower)
	return values[inside].std()

# calculates Median Absolute Deviation [WH]
def arr_MAD(arr, median=None):
	"""Median Absolute Deviation: median(|arr - median(arr)|) over all elements

	The median of arr is computed with arr_median unless it is passed in.
	"""
	if median is None:
		median = arr_median(arr)
	deviations = numpy.fabs(arr.ravel() - median)
	return arr_median(deviations)

# image division taking care of division by zero
def im_divide(dividend, divisor):
	"""divide two 2d images element-wise, replacing inf and NaN results with 1.0

	numpy floating-point warnings for division by zero, invalid operations
	and underflow are suppressed during the division. numpy.errstate
	restores the previous error handling afterwards, also when the
	division raises (e.g. for shapes that cannot be broadcast).
	"""
	with numpy.errstate(divide='ignore', invalid='ignore', under='ignore'):
		result = dividend / divisor
	# x/0 gives +-inf and 0/0 gives NaN: both are replaced by 1.0
	result[~numpy.isfinite(result)] = 1.0
	return result

# =====================================================================================
# 3. convenience functions to access fits header keys and data
# =====================================================================================

# get values of header key
def get_keyval(header, key, default=-1):
	"""return header[key], or default if the keyword does not exist in header"""
	return header[key] if key in header else default

# average of two header key values
def get_avgkeyval(header, key1, key2, default=0):
	"""average of two header key values;
	example: key1='HIERARCH ESO TEL AMBI FWHM START', key2='HIERARCH ESO TEL AMBI FWHM END'

	Returns 0.5*(value1 + value2) if both keywords exist, the value of
	the only existing keyword if just one exists, and default otherwise.
	"""
	has_key1 = key1 in header
	has_key2 = key2 in header
	if has_key1 and has_key2:
		return 0.5*(header[key1]+header[key2])
	if has_key1:
		return header[key1]
	if has_key2:
		# the former condition 'not key2 in header and key2 in header' made this case unreachable
		return header[key2]
	return default

# average and rms from list of headers
def get_avgfromhdrlist(HDUlist, key1, key2='', default=-99, ext=0):
	"""average and rms of a keyword value across a list of headers

	HDUlist: sequence of HDU lists; the header of extension ext is used
	key1:    keyword to average
	key2:    optional second keyword. If given, the per-header value is
	         get_avgkeyval(hdr, key1, key2, default), and headers without
	         both keywords are skipped. Otherwise headers without key1 are
	         skipped.
	Returns (mean, rms), with rms the sample standard deviation (ddof=1);
	(default, -1) if no header provides a value, and (value, -1) if only
	one does. numpy is used because the scipy.mean/scipy.std aliases of
	NumPy are no longer available in current SciPy releases.
	"""

	values = []
	for HDU in HDUlist:
		hdr = HDU[ext].header
		if key2 != '':
			if key1 in hdr or key2 in hdr:
				values.append(get_avgkeyval(hdr, key1, key2, default=default))
		elif key1 in hdr:
			values.append(hdr[key1])
	if len(values) == 0:
		return default, -1
	if len(values) == 1:
		return numpy.mean(values), -1
	# numpy.std with ddof=1: sqrt(sum((x_i - mean)^2) / (N - 1))
	return numpy.mean(values), numpy.std(values, ddof=1)

# calculate wavelengths for a spectrum in binary fits file
def get_wavearray(hdr, wshape=-1, offset=0., axis=1):
	"""array of wavelengths from CRVAL<axis>, CRPIX<axis> and CDELT<axis>

	wave[i] = CRVAL + CDELT * (i - (CRPIX - 1 + offset)) for i = 0 .. n-1,
	where a missing CRVAL/CRPIX/CDELT keyword counts as 1 and
	n = wshape if wshape > 0, else NAXIS<axis> (1 if missing).
	"""
	suffix = str(axis)

	def keyval(name):
		# header value of name+axis, 1 if the keyword does not exist
		keyword = name + suffix
		return hdr[keyword] if keyword in hdr else 1

	crval = keyval('CRVAL')
	crpix = keyval('CRPIX')
	cdelt = keyval('CDELT')
	npix = wshape if wshape > 0 else keyval('NAXIS')
	return numpy.arange(0, npix) * cdelt + crval - cdelt*(crpix - 1. + offset)

# get data array from HDU
def get_arr(HDU, ext=0, default=0):
	"""return the data array of extension ext of HDU with NaN values replaced by default

	The replacement is done in place: HDU[ext].data itself is modified and
	returned.
	"""
	data = HDU[ext].data
	nan_pixels = numpy.isnan(data)
	data[nan_pixels] = default
	return data

# returns number of over-exposed pixels
def overexp(HDU, ext=0, threshold=60000):
	"""number of pixels in extension ext of HDU whose value is >= threshold"""
	pixels = HDU[ext].data.ravel()
	return int(numpy.count_nonzero(pixels >= threshold))

# returns maximum number of over-exposed pixels from a list of frames
def overexp_fromlist(HDUlist, ext=0, threshold=60000):
	"""largest number of over-exposed pixels (value >= threshold) in extension ext
	of any frame in HDUlist; 0 for an empty list"""
	counts = [overexp(HDU, ext=ext, threshold=threshold) for HDU in HDUlist]
	return max([0] + counts)

# =====================================================================================
# 4. QC plotting routines
# =====================================================================================

def paper_size(size = 'a4'):
	"""return the size of an A4, A5 or A6 page in inch as (long side, short side)

	size is case-insensitive; unknown values give the A4 size.
	"""
	key = size.lower()
	if key == 'a6':
		return (5.8,4.1)
	if key == 'a5':
		return (8.3,5.8)
	return (11.7,8.3)

def get_fontsizes(paper_size = 'a4'):
	"""return the (title, subtitle, footer) font sizes for an A4, A5 or A6 page

	paper_size is case-insensitive, as in the function paper_size(); any
	other value gives the A4 sizes.
	"""
	# accept 'A5' as well as 'a5'; previously 'A5' silently fell back to the A4 sizes
	size = paper_size.lower() if hasattr(paper_size, 'lower') else paper_size
	if size == 'a6':
		return (8, 6, 5)
	if size == 'a5':
		return (11, 7, 7)
	return (14, 10, 8)

class BasicPlot:
	"plots 1D data as step curves into the current pylab axes"

	def __init__(self):
		# list of (data, line colour, legend label) tuples in drawing order
		self.plotdata = []
		return

	def add_data(self, data, name='', lin_col='0.0'):
		"""register the 1D array data for plotting with legend label name and matplotlib colour lin_col"""
		self.plotdata.append((data, lin_col, name))
		return

	def draw(self, xrange=(), yrange=(), xarray=None):
		"""plot all registered data (drawstyle 'steps-mid') and set the axis limits

		xrange: (xmin, xmax); if empty, taken from xarray or from the data
		        length with a 10 pixel margin on both sides
		yrange: (ymin, ymax); if empty, derived from the data: the tighter of
		        median +/- 5 clipped sigma and the full data range, extended
		        by 10% below and 30% above (presumably to leave room for the legend)
		xarray: optional x values shared by all data arrays (default: pixel index)
		Returns (xmin, xmax, ymin, ymax).
		"""
		if xarray is None:
			xarray = []
		ymin1, ymin2, ymax1, ymax2 = [], [], [], []
		xmax1 = []
		for line, lin_col, name in self.plotdata:
			if len(xarray) == 0:
				pylab.plot(line, color=lin_col, drawstyle='steps-mid', label=name)
			else:
				pylab.plot(xarray, line, color=lin_col, drawstyle='steps-mid', label=name)

			# prepare for proper scaling: robust range and full data range
			mu = arr_median(line)
			sigma = arr_stddev(line, 5.0, mu)
			ymin1.append(mu-5*sigma)
			ymax1.append(mu+5*sigma)
			ymin2.append(line.min())
			ymax2.append(line.max())
			xmax1.append(line.shape[0])

		# scaling in x
		if len(xrange) != 0:
			xmin = xrange[0]
			xmax = xrange[1]
		elif len(xarray) != 0:
			xmin = xarray[0]
			xmax = xarray[-1]
		else:
			xmin = -10
			xmax = numpy.array(xmax1).max() + 10
		pylab.xlim(xmin, xmax)

		# scaling in y
		if len(yrange) == 0:
			ymin = numpy.array([numpy.array(ymin1).min(), numpy.array(ymin2).min()]).max()
			ymax = numpy.array([numpy.array(ymax1).max(), numpy.array(ymax2).max()]).min()
			if ymin == ymax:
				ymin = ymin - abs(ymin)
				ymax = ymax + abs(ymax)
				if ymin == ymax:
					# all data are zero: use a unit range instead of an empty one
					ymin, ymax = -1.0, 1.0
			ydiff = ymax - ymin
			ymin = ymin - 0.1 * ydiff
			ymax = ymax + 0.3 * ydiff
		else:
			ymin = yrange[0]
			ymax = yrange[1]
		pylab.ylim(ymin, ymax)

		return xmin, xmax, ymin, ymax

class LinePlot(BasicPlot):
	"plots single rows/columns or row/column averages of 2D images (axis limits via BasicPlot.draw)"

	def __init__(self):
		BasicPlot.__init__(self)
		return

	def get_line(self, image, lin, axis):
		"""return row lin (axis='row') or column lin (axis='col') of image"""
		if axis == 'row':
			return image[lin,:]
		if axis == 'col':
			return image[:,lin]
		raise ValueError("LinePlot: axis must be 'row' or 'col', got " + repr(axis))

	def get_avg(self, image, axis):
		"""return the average row (axis='row') or average column (axis='col') of image"""
		if axis == 'row':
			return numpy.sum(image,axis=0) / float(image.shape[0])
		if axis == 'col':
			return numpy.sum(image,axis=1) / float(image.shape[1])
		raise ValueError("LinePlot: axis must be 'row' or 'col', got " + repr(axis))

	def add_line(self, image, axis='row', line='@0', name='', lin_col='0.0'):
		"""register one line of image for plotting

		line: '@<n>' for row/column number n, or 'avg' for the average row/column
		axis: 'row' or 'col'
		Raises ValueError for any other line or axis specification (previously
		this failed later with an UnboundLocalError or plotted None).
		"""
		if line.startswith('@'):
			data = self.get_line(image, int(line[1:]), axis)
		elif line == 'avg':
			data = self.get_avg(image, axis)
		else:
			raise ValueError("LinePlot.add_line: line must be '@<number>' or 'avg', got " + repr(line))
		self.plotdata.append((data, lin_col, name))
		return

class HistoPlot:
	"""plots histograms of 2D images and Gaussian fits to these histograms

	All histograms share the same bins. The bins are derived from the first
	registered image and the settings given to the constructor.
	"""

	def __init__(self, logplot=False, stepsize=0, binrange=5.0, binnum=0):
		"""
		logplot:  if True, plot log10 of the counts (zero counts are plotted as 0)
		stepsize: bin width, used if > 0 and binnum is not > 0
		binrange: scalar k -> bins cover median +/- k * clipped sigma of the first image;
		          (min, max) pair -> bins cover this explicit range
		binnum:   number of bins, used if > 0 (overrides stepsize; bin width is at least 1)
		"""
		self.histo_def = []
		self.logplot = logplot
		self.stepsize = stepsize
		self.binrange = binrange
		self.binnum = binnum
		# not used by draw(); kept only for code that reads the attribute (it used to
		# hold the Python 2 builtin xrange by mistake, a NameError under Python 3)
		self.xrange = ()
		return

	def add_histo(self, image, fit=False, name='', lin_col='0.0'):
		"""register image: plot its histogram (fit=False) or a Gaussian fit to it (fit=True)

		lin_col is a matplotlib colour. The default '0.0' is a grey-level
		string (black); a float such as 0.0 is not a valid matplotlib colour.
		"""
		self.histo_def.append((image, fit, name, lin_col))
		return

	def _get_bins(self, image):
		"""return the lower bin edges shared by all histograms, derived from image and the bin settings"""
		try:
			lenrange = len(self.binrange)
		except TypeError:
			# scalar binrange: bins are centred on the median of image
			lenrange = 0
		if lenrange == 2:
			if self.binnum > 0:
				# binnum > 0: ignore self.stepsize
				stepsize = (self.binrange[1] - self.binrange[0]) / self.binnum
				if stepsize < 1.0:
					# avoid empty bins if stepsize is too small
					stepsize = 1.0
				return numpy.arange(self.binrange[0], self.binrange[1], stepsize)
			if self.stepsize > 0:
				return numpy.arange(self.binrange[0], self.binrange[1], self.stepsize)
			sigma = 5.0 * arr_stddev(image)
			return numpy.arange(self.binrange[0], self.binrange[1], 0.05*sigma)

		center = arr_median(image)
		sigma = self.binrange * arr_stddev(image, median=center)
		if self.binnum > 0:
			# binnum > 0: ignore self.stepsize
			stepsize = 2 * sigma / self.binnum
			if stepsize < 1.0:
				# avoid empty bins if stepsize is too small
				stepsize = 1.0
			return numpy.arange(center-sigma, center+sigma, stepsize)
		if self.stepsize > 0:
			if self.stepsize % 1 == 0:
				# if stepsize is integer make sigma integer
				sigma = math.ceil(sigma)
			return numpy.arange(center-sigma, center+sigma, self.stepsize)
		return numpy.arange(center-sigma, center+sigma, 0.05*sigma)

	def draw(self):
		"""plot all registered histograms/fits into the current axes and set the axis limits

		Returns (xmin, xmax, ymin, ymax) of the plot. The y range comes from
		the plotted histograms, or from the fitted curves if only fits were
		registered.
		"""
		xbins = self._get_bins(self.histo_def[0][0])
		xplot = xbins
		ymin, ymax = [], []
		fit_ymin, fit_ymax = [], []
		for image, fit, name, lin_col in self.histo_def:
			hm0 = histogram(image.ravel(), xbins)
			# shift histogram by one bin for plotting with drawstyle='steps' (= 'steps-pre'),
			# which draws the value at xbins[i] over the interval (xbins[i-1], xbins[i]]
			hm = numpy.roll(hm0, 1)
			hm[0] = hm[1]
			if not fit:
				if self.logplot:
					# pylab.semilogy often raises exceptions, therefore log10 is plotted directly
					hm = numpy.where(hm == 0.0, 1.0, hm)
					hmlog = numpy.log10(hm)
					ymin.append(min(hmlog))
					ymax.append(max(hmlog))
					pylab.plot(xplot, hmlog, drawstyle='steps', label=name, color=lin_col)
				else:
					ymin.append(min(hm))
					ymax.append(max(hm))
					pylab.plot(xplot, hm, drawstyle='steps', label=name, color=lin_col)
			else:
				# NOTE(review): the shifted counts are fitted at xbins[i], i.e. (except for the
				# first bin) at the upper edge of their bin, which presumably biases the fitted
				# mu by about half a bin width -- confirm before relying on the fit position
				xbinf = xbins[:-1]
				hmf = hm[0:hm.size-1]
				mu = arr_median(image)
				p0 = [max(hmf), mu, arr_stddev(image, median=mu)]
				plsq = scipy.optimize.leastsq(residuals, p0, args=(hmf, xbinf))
				curve = pval(xbins, plsq[0])
				if self.logplot:
					curve = numpy.log10(numpy.where(curve == 0.0, 1.0, curve))
				fit_ymin.append(numpy.nanmin(curve))
				fit_ymax.append(numpy.nanmax(curve))
				pylab.plot(xplot, curve, label=name, color=lin_col)

		if len(ymin) == 0:
			# only fits were plotted: scale to the fitted curves instead of failing on empty lists
			ymin, ymax = fit_ymin, fit_ymax
		# scaling for plot
		ylow = min(ymin)
		yhigh = max(ymax) + 0.05 * (max(ymax) - min(ymin))
		pylab.xlim(xplot.min(), xplot.max())
		pylab.ylim(ylow, yhigh)
		return xplot.min(), xplot.max(), ylow, yhigh

class ImagePlot:
	"""displays 2D images

	Images (and optional symbol overlays) are registered with add_image and
	add_symbols and then rendered into a grid of pylab sub plots by draw.
	"""

	def __init__(self, image_map=(1,1)):
		# image_map = (number of sub plots in x, number of sub plots in y);
		# draw passes it to pylab.subplot as (rows=image_map[1], columns=image_map[0])
		self.image_list = []
		self.symbol_list = []
		self.image_map = image_map
	def add_image(self, image, pos, name, cuts=(1.0,), xrange=(), yrange=(), compress=1, cmap=col_map, xlabel='X', ylabel='Y'):
		"""Register an image for sub plot number pos.

		image	-- 2D array
		pos	-- 1-based sub plot index passed to pylab.subplot
		name	-- text used in the sub plot title
		cuts	-- (k,): cuts at median -/+ k*sigma; (lcut, hcut): explicit cuts;
			   any other length (e.g. ()): image minimum and maximum
		xrange, yrange -- (start, end) pixel ranges cut out of the (compressed) image
			   in draw; any other length means the full axis
		compress -- keep only every compress-th pixel along both axes
		cmap	-- colour map passed to pylab.imshow
		xlabel, ylabel -- axis labels
		"""
		if compress > 1:
			# sub-sample rows first, then columns (axis 1 length is unchanged by the first take)
			temp_ima = numpy.take(image, range(0, image.shape[0], compress), axis=0)
			image = numpy.take(temp_ima, range(0, image.shape[1], compress), axis=1)
		self.image_list.append((image, pos, name, cuts, xrange, yrange, compress, cmap, xlabel, ylabel))
	def add_symbols(self, xvalues, yvalues, pos, symbol='o', symsize=5, edgecolor='b', edgewidth=0.75, facecolor='none'):
		"""Register symbols (in original pixel coordinates) to be overplotted on sub plot pos."""
		self.symbol_list.append((xvalues, yvalues, pos, symbol, symsize, edgecolor, edgewidth, facecolor))
	def draw(self, addcol=(), addrow=(), addsquare=(), all_labels=True, title_length='long', 
			aspect='auto', colorbar=False, figsize='a4'):
		"""Render all registered images into pylab sub plots.

		addcol, addrow -- 1-tuple with a column / row (original pixel coordinates)
			   marked by a blue line in every sub plot
		addsquare -- (x0, y0, x1, y1) in original pixel coordinates, drawn as a blue box
		all_labels -- if False, tick labels of inner sub plots are blanked when the
			   grid has more than 3 sub plots in that direction
		title_length -- 'short' omits 'cuts=' from the sub plot title
		aspect	-- passed to pylab.imshow
		colorbar -- add a colour bar to each sub plot
		figsize	-- paper size used to select font sizes via get_fontsizes
		"""
		# TODO: move all_labels, aspect, figsize to __init__ ???
		# TODO: add color map
		fs_title, fs_sub, fs_foot = get_fontsizes(figsize)
		for image_def in self.image_list:
			image = image_def[0]
			pos = image_def[1]
			name = image_def[2]
			cuts = image_def[3]
			xrange = image_def[4]
			yrange = image_def[5]
			compress = image_def[6]
			cmap = image_def[7]
			xlabel = image_def[8]
			ylabel = image_def[9]
			# cut out the requested sub image; xrange/yrange default to the full axes
			if len(xrange) == 2:
				image = image[:,xrange[0]:xrange[1]]
			else:
				xrange = [0, image.shape[1]]
			if len(yrange) == 2:
				image = image[yrange[0]:yrange[1],:]
			else:
				yrange = [0, image.shape[0]]
			# display cuts
			if len(cuts) == 1:
				median = arr_median(image)
				sigma = arr_stddev(image, thres=5.0, median=median)
				lcut = median-cuts[0]*sigma
				hcut = median+cuts[0]*sigma
			elif len(cuts) == 2:
				lcut = cuts[0]
				hcut = cuts[1]
			else:
				# use minimum and maximum
				lcut = image.min()
				# force hcut slightly above max value, otherwise pixel with max value becomes black
				# this seems to be a bug in matplotlib
				hcut = image.max()
				hcut = hcut + 0.01 * abs(hcut)
			pylab.subplot(self.image_map[1], self.image_map[0], pos)
			pylab.imshow(image, cmap, vmin=lcut, vmax=hcut, origin='lower', 
					aspect=aspect, interpolation='nearest')
			if colorbar:
				pylab.colorbar()
			# axis labels: with many rows only the bottom row gets an x label,
			# with many columns only the left column gets a y label
			if self.image_map[1] > 2:
				if pos > (self.image_map[1] - 1) * self.image_map[0]:
					pylab.xlabel(xlabel, size=fs_foot)
			else:
				pylab.xlabel(xlabel, size=fs_foot)
			if self.image_map[0] > 3:
				if (pos - 1) % self.image_map[0] == 0:
					pylab.ylabel(ylabel, size=fs_foot)
			else:
				pylab.ylabel(ylabel, size=fs_foot)
			# overlays are given in original pixel coordinates: scale by compress,
			# then shift by the start of the displayed sub image
			if len(addrow) == 1:
				addline = addrow[0] / compress
				pylab.plot([-0.5, image.shape[1]-0.5], [addline-yrange[0], addline-yrange[0]], 'b-')
				# until 1.5.1:
				#addline = addrow[0] // compress
				#pylab.plot([0, image.shape[1]], [addline-yrange[0], addline-yrange[0]], 'b-')
			if len(addcol) == 1:
				addline = addcol[0] / compress
				pylab.plot([addline-xrange[0], addline-xrange[0]], [-0.5, image.shape[0]-0.5], 'b-')
				# until 1.5.1:
				#addline = addcol[0] // compress
				#pylab.plot([addline-xrange[0], addline-xrange[0]], [0, image.shape[0]], 'b-')
			if len(addsquare) == 4:
				square =(addsquare[0] // compress - 1, addsquare[1] // compress - 1, addsquare[2] // compress, addsquare[3] // compress)
				# adding -1 is formally wrong here but avoids having the square in an enlargement plot
				#square =(addsquare[0] // compress, addsquare[1] // compress, addsquare[2] // compress, addsquare[3] // compress)
				pylab.plot([square[0]-xrange[0],square[2]-xrange[0],square[2]-xrange[0],
						square[0]-xrange[0],square[0]-xrange[0]],
					   [square[1]-yrange[0],square[1]-yrange[0],square[3]-yrange[0],
					        square[3]-yrange[0],square[1]-yrange[0]], 'b-')
			# overplot symbols registered for this sub plot
			if len(self.symbol_list) > 0:
				for sym_def in self.symbol_list:
					if pos == sym_def[2]:
						sym_posx = []
						sym_posy = []
						for idx in range(len(sym_def[0])):
							sym_posx.append((sym_def[0][idx] // compress) - xrange[0])
							sym_posy.append((sym_def[1][idx] // compress) - yrange[0])
						pylab.plot(sym_posx, sym_posy, linewidth = 0,
							marker=sym_def[3], markersize=sym_def[4], 
							markeredgecolor=sym_def[5], markeredgewidth=sym_def[6],
							markerfacecolor=sym_def[7])
			
			pylab.xlim(0, image.shape[1])
			pylab.ylim(0, image.shape[0])
			# tick marks and locators, complicated if plotted image different from original image
			# NOTE(review): locs[1] is used as the tick spacing, i.e. the first automatic
			# tick is assumed to be at 0 after the xlim/ylim calls above — confirm
			# x axis
			locs, labels = pylab.xticks()
			base = locs[1] // compress
			if base <= 0:
				# tick spacing below the compression factor (or below one pixel) would
				# give a zero step, and numpy.arange raises on a zero step
				base = 1
			locs = numpy.arange((xrange[0]//base)*base, xrange[1]+base, base) - xrange[0]
			labels = []
			if all_labels:
				for locator in locs:
					labels.append('%i' % (locator*compress+xrange[0],))
			else:
				if self.image_map[1] > 3 and pos <= (self.image_map[1] - 1) * self.image_map[0]:
					for locator in locs:
						labels.append('')
				else:
					for locator in locs:
						labels.append('%i' % (locator*compress+xrange[0],))
			# v1.3.7
			if self.image_map[0] > 4 and len(labels) > 6:
				# plot only every 2nd locator
				for idx in range(1, len(labels), 2):
					labels[idx] = ''
			pylab.xticks(locs, labels, size=fs_foot)
			# y axis
			locs, labels = pylab.yticks()
			base = locs[1]
			if base <= 0:
				# same guard as for the x axis: never hand a zero step to numpy.arange
				base = 1
			locs = numpy.arange((yrange[0]//base)*base, yrange[1]+base, base) - yrange[0]
			labels = []
			if all_labels:
				for locator in locs:
					labels.append('%i' % (locator*compress+yrange[0],))
			else:
				if self.image_map[0] > 3 and ((pos - 1) % self.image_map[0] != 0):
					for locator in locs:
						labels.append('')
				else:
					for locator in locs:
						labels.append('%i' % (locator*compress+yrange[0],))
			pylab.yticks(locs, labels, size=fs_foot)
			# finalizing: pixel centres are at integer positions, so extend by half a pixel
			pylab.xlim(-0.5, image.shape[1]-0.5)
			pylab.ylim(-0.5, image.shape[0]-0.5)
			if title_length == 'short':
				title = '%s [%5.2f,%5.2f]' % (name, lcut, hcut)
			else:
				title = '%s [cuts=%5.2f,%5.2f]' % (name, lcut, hcut)
			pylab.title(title, size=fs_foot)

class SpecPlot:
	"""plots 1D spectra, y scaling different from BasicPlot

	Spectra are registered with add_spec and drawn as step curves by draw.
	"""

	def __init__(self):
		self.spec = []	# list of (spectrum, line colour, label) tuples
		return
	def add_spec(self, spec, name='', lin_col=0.0):
		"""Register a spectrum.

		spec	-- 1D sequence of values
		name	-- legend label
		lin_col	-- line colour passed to pylab.plot; a plain number is taken as a
			   grey level and converted to matplotlib's grey-scale string form
			   (e.g. 0.0 -> '0.0', black), since matplotlib rejects bare floats
		"""
		if isinstance(lin_col, (int, float)) and not isinstance(lin_col, bool):
			lin_col = str(float(lin_col))
		self.spec.append((spec, lin_col, name))
		return
	def draw(self, xrange=(), yrange=(), xarray=()):
		"""Plot all registered spectra and set the axis limits.

		xrange	-- (xmin, xmax); default: pixel range of the longest spectrum,
			   or the range of xarray if given
		yrange	-- (ymin, ymax); default: derived per spectrum from the median mu
			   and sigma = arr_stddev(spec, 5.0, mu):
			   lower limit min(mu - 5*sigma), but never above 0,
			   upper limit max(max(mu + 5*sigma, 1.5*mu))
		xarray	-- optional x values shared by all spectra; default: pixel index
		Returns (xmin, xmax, ymin, ymax) of the plot.
		"""
		xmax_list = []
		ymax_list = []
		ymin_list = []
		for spec in self.spec:
			if len(xarray) == 0:
				pylab.plot(spec[0], color=spec[1], linestyle='steps', label=spec[2])
				xmax_list.append(len(spec[0]))
			else:
				pylab.plot(xarray, spec[0], color=spec[1], linestyle='steps', label=spec[2])
			mu = arr_median(spec[0])
			sigma = arr_stddev(spec[0], 5.0, mu)
			ymax1 = mu+5*sigma
			ymax2 = 1.5 * mu
			ymax_list.append(max([ymax1,ymax2]))
			ymin_list.append(mu-5*sigma)
		# scaling: explicit xrange wins, then xarray, then pixel index
		if len(xrange) != 0:
			xmin = xrange[0]
			xmax = xrange[1]
		elif len(xarray) != 0:
			# numpy.min/max accept plain lists as well as arrays
			xmin = numpy.min(xarray)
			xmax = numpy.max(xarray)
		else:
			xmin = 0
			xmax = max(xmax_list)
		if len(yrange) != 0:
			ymin = yrange[0]
			ymax = yrange[1]
		else:
			# lowest lower limit of all spectra, but always include zero
			ymin = min(ymin_list)
			if ymin > 0:
				ymin = 0
			ymax = max(ymax_list)
		pylab.xlim(xmin, xmax)
		pylab.ylim(ymin, ymax)
		return xmin, xmax, ymin, ymax

class BarPlot:
	"""plots multiple data sets with error bars

	Data sets are registered with add_data and plotted with pylab.errorbar by draw.
	"""

	def __init__(self):
		self.data = []	# list of (xdata, ydata, errors, name, colour) tuples
		return
	def add_data(self, xdata, ydata, errors, name='', colour='0.0'):
		"""Register a data set.

		xdata, ydata -- data point positions
		errors	-- y errors, passed to pylab.errorbar as yerr
		name	-- legend label
		colour	-- error bar colour (matplotlib colour spec; '0.0' is a black grey level)
		"""
		self.data.append((xdata, ydata, errors, name, colour))
		return
	def draw(self, xrange=(), yrange=()):
		"""Plot the error bars of all registered data sets.

		xrange, yrange -- (min, max) axis limits; empty (default) leaves the limits untouched
		"""
		for xdat, ydat, err, lab, col in self.data:
			# NOTE(review): fmt=None (no data markers) is deprecated since matplotlib 1.4
			# in favour of fmt='none'; switch once the minimum supported version allows it
			pylab.errorbar(xdat, ydat, yerr=err, fmt=None, ecolor=col, capsize=2, label=lab)
		# test the length rather than comparing with [], so that empty tuples and
		# numpy arrays behave like lists
		if len(xrange) != 0:
			pylab.xlim(xrange[0], xrange[1])
		if len(yrange) != 0:
			pylab.ylim(yrange[0], yrange[1])
		return

# =====================================================================================
# 4. handle standalone call
# =====================================================================================

if __name__ == '__main__':
	# library module: besides argparse's automatic --help, only --version is offered
	about = 'Module with definitions, functions, and classes for pyQC. No standalone usage.'
	cli = argparse.ArgumentParser(description=about)
	cli.add_argument('--version', action='version', version='%(prog)s ' + _version_)
	cli.parse_args()
	set_logging(level='info')
	logging.info(about + ' Exit.')
	sys.exit(0)

