import numpy as np
from astropy import coordinates as c
from astropy import units as u
import math, sys, os
from matplotlib import pyplot as plt
from astropy.io import ascii, fits
from scipy import interpolate as interp

def _emit(f, text):
    '''Print `text` to stdout and also write it, newline-terminated, to the open file `f`.'''
    print(text)
    f.write(text + '\n')


def main():
    '''Select telluric standards for every target listed in `targets.txt`.

    For each target row, up to five tellurics per catalogue are ranked with
    `pick_telluric()`. The ranked list is printed, written to
    `output-list-<target>.txt`, and the airmass difference versus OB start
    hour angle is plotted to `output-AM-<target>.png`. Spaces in the target
    name are replaced by underscores in both file names.

    Reads `targets.txt` and `telluric-sample.txt` from the working directory.
    It also calls `load_etc_prediction()` for the ETC interpolators.

    Raises ValueError if any target parameter is out of range.
    Returns 0.
    '''

    # Telluric before or after OB?
    after = True

    # Method for selecting best telluric, see `pick_telluric()`
    method = 'avg'

    # Hard-coded limits for RUWE/NSS
    max_ruwe = 1.4
    filter_nss = True

    if filter_nss:
        max_nss = 0
    else:
        max_nss = 999

    d = ascii.read('targets.txt')
    target_name = d['target'].data
    ra = d['ra_deg'].data
    dec = d['dec_deg'].data
    ob_duration = d['ob_duration'].data
    ao_type = d['ao_type'].data
    spxw = d['spxw'].data
    spgw = d['spgw'].data
    max_am_ob = d['max_am_ob'].data
    max_am_tel = d['max_am_tel'].data

    n = len(ra)

    # Unit conversion
    dr = np.pi/180.0
    rd = 180.0/np.pi
    ra = [x*dr for x in ra]
    dec = [x*dr for x in dec]

    # Valid values for spgw
    valid_spgw = ('J_low', 'J_short', 'J_middle', 'J_long',
                  'H_low', 'H_short', 'H_middle', 'H_long',
                  'K_low', 'K_short', 'K_middle', 'K_long')

    # Test parameters are in range, single value
    for i in range(0, n):
        if not 0.0 <= ra[i] <= 2.0*np.pi:
            raise ValueError('RA (deg) not in range for entry {}'.format(i))
        if not -np.pi/2. <= dec[i] <= np.pi/2.0:
            raise ValueError('Dec (deg) not in range for entry {}'.format(i))
        if not 1.0 <= ob_duration[i] <= 180.0:
            raise ValueError('OB duration (min) not in range for entry {}'.format(i))
        if ao_type[i] not in ('AO', 'noAO'):
            raise ValueError('AO type must be either "noAO" or "AO" for entry {}'.format(i))
        if spxw[i] not in ('25mas', '100mas', '250mas'):
            raise ValueError('spxw must be either "25mas", "100mas", or "250mas" for entry {}'.format(i))
        if spgw[i] not in valid_spgw:
            raise ValueError('spgw must be a valid instrument setup for entry {}'.format(i))
        if not 1.0 <= max_am_ob[i] <= 2.9:
            raise ValueError('Maximum OB airmass must be between 1.0 and 2.9 for entry {}'.format(i))
        if not 1.0 <= max_am_tel[i] <= 2.9:
            raise ValueError('Maximum telluric airmass must be between 1.0 and 2.9 for entry {}'.format(i))

        # AO cannot be used beyond airmass 1.9: clamp both limits in place
        if (ao_type[i] == 'AO') and max_am_ob[i] > 1.9:
            print('Maximum airmass with AO is 1.9, setting max_am_ob = 1.9 for entry {}'.format(i))
            max_am_ob[i] = 1.9
        if (ao_type[i] == 'AO') and max_am_tel[i] > 1.9:
            print('Maximum airmass with AO is 1.9, setting max_am_tel = 1.9 for entry {}'.format(i))
            max_am_tel[i] = 1.9

    # Telluric standard sample
    data = ascii.read('telluric-sample.txt')
    tel_sample_name = data['Name'].data
    tel_sample_ra = data['RA'].data * dr
    tel_sample_dec = data['Dec'].data * dr
    tel_sample_source = data['Catalogue'].data
    tel_sample_ruwe = data['RUWE'].data
    tel_sample_nss = data['NSS'].data
    tel_sample_rpmag = data['Gaia_RP'].data
    tel_sample_jmag = data['J'].data
    tel_sample_hmag = data['H'].data
    tel_sample_kmag = data['K'].data
    tel_sample_neighbours = data['N_neighbours'].data

    # Telluric magnitude column for each band; the band is the prefix of spgw
    band_mags = {'J': tel_sample_jmag, 'H': tel_sample_hmag, 'K': tel_sample_kmag}

    # Create interpolation objects for DIT, SNR, max value
    dit_ao, snr_ao, max_ao, dit_noao, snr_noao, max_noao = load_etc_prediction()

    header = 'Name                      \tRA          \tDec          \tBand\tMag\tDIT\tSNR\tMax\tRP mag\tAO mode\tMin AM\tMax AM\tAvg AM\tSD AM\tSource'
    row_fmt = '{:<30}\t{:.8f}\t{:.8f}\t{}\t{:.1f}\t{:.1f}\t{:.0f}\t{:.0f}\t{:.1f}\t{}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{}'

    for i in range(0, n):
        file_tag = target_name[i].replace(' ', '_')
        with open('output-list-{}.txt'.format(file_tag), 'w') as f:

            band = spgw[i].split('_')[0]
            tel_sample_mag = band_mags[band]

            # Perform interpolation with the ETC tables matching the AO mode
            if ao_type[i] == 'AO':
                dit_tab, snr_tab, max_tab = dit_ao, snr_ao, max_ao
            else:
                dit_tab, snr_tab, max_tab = dit_noao, snr_noao, max_noao
            dit = dit_tab[spxw[i]][spgw[i]](tel_sample_mag)
            snr = snr_tab[spxw[i]][spgw[i]](tel_sample_mag)
            max_val = max_tab[spxw[i]][spgw[i]](tel_sample_mag)

            fig, ax = plt.subplots(3, figsize=(5.0, 5.0), dpi=150, sharex=True, sharey=True)
            try:
                for j, sample in enumerate(('xshooter-B-verified', 'xshooter-B', 'gaiaDR3-solar-analogues')):
                    # Candidate tellurics from this catalogue that pass the quality cuts.
                    # A NaN DIT means the magnitude is outside the ETC grid.
                    indx = np.where((tel_sample_source == sample) &
                                    np.isfinite(dit) &
                                    (tel_sample_rpmag <= 11.0) &
                                    (tel_sample_ruwe < max_ruwe) &
                                    (tel_sample_nss <= max_nss) &
                                    (tel_sample_neighbours == 0))[0]

                    result = pick_telluric(ra[i], dec[i],
                                           tel_sample_ra[indx], tel_sample_dec[indx],
                                           ob_duration[i], after=after,
                                           max_am_OB=max_am_ob[i], max_am_tel=max_am_tel[i],
                                           method=method)

                    if j == 0:
                        _emit(f, '# Target: {}, RA: {:.8f}, Dec: {:.8f}, grating: {}, plate scale: {}, AO mode: {}'.format(
                            target_name[i], ra[i]*rd, dec[i]*rd, spgw[i], spxw[i], ao_type[i]))
                        _emit(f, header)

                    # len(indx) == 0 guard: np.max() of an empty visibility array raises
                    if len(indx) == 0 or np.max(result['visibility']) == 0:
                        _emit(f, 'No valid telluric standards, check airmass constriants')
                    else:
                        for k in result['tel_ind'][0:5]:
                            s = indx[k]  # position of this telluric in the full sample
                            if ao_type[i] == 'AO':
                                # NOTE(review): the selection above already requires RP <= 11,
                                # so this is always 'NGS'; 'LGS' is only reachable if that cut changes.
                                if tel_sample_rpmag[s] <= 11.0:
                                    rec_ao_mode = 'NGS'
                                else:
                                    rec_ao_mode = 'LGS'
                            else:
                                rec_ao_mode = 'noAO'

                            if np.isfinite(result['min_delta_am'][k]):
                                _emit(f, row_fmt.format(tel_sample_name[s],
                                                        tel_sample_ra[s] * rd,
                                                        tel_sample_dec[s] * rd,
                                                        band,
                                                        tel_sample_mag[s],
                                                        dit[s],
                                                        snr[s],
                                                        max_val[s],
                                                        tel_sample_rpmag[s],
                                                        rec_ao_mode,
                                                        result['min_delta_am'][k],
                                                        result['max_delta_am'][k],
                                                        result['avg_delta_am'][k],
                                                        result['std_delta_am'][k],
                                                        sample))

                    for k in result['tel_ind'][0:5]:
                        if np.isfinite(result['min_delta_am'][k]):
                            ax[j].plot(result['ha_start']*rd/15., result['delta_am'][:, k],
                                       label=tel_sample_name[indx[k]])
                    ax[j].axhline(0, color='gray', lw=0.75)
                    ax[j].axhline(0.1, color='gray', lw=0.75, ls='--')
                    ax[j].axhline(0.2, color='gray', lw=0.75, ls=':')

                    ax[j].legend(loc='center left', bbox_to_anchor=(1, 0.5))
                    ax[j].annotate(sample, xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')

                ax[1].set_ylabel(r'$\Delta$ airmass')
                ax[-1].set_xlabel('Hour Angle at OB start (hrs)')
                fig.subplots_adjust(hspace=0.02, wspace=0.02)
                fig.savefig('output-AM-{}.png'.format(file_tag), dpi=300, bbox_inches='tight')
            finally:
                # Release the figure; otherwise one stays open per target
                plt.close(fig)

    return 0

def _etc_interpolators(etc):
    '''Build {spxw: {spgw: interp1d}} lookups of DIT, SNR and max value vs. magnitude from one ETC table.

    `etc` must provide columns 'spxw', 'spgw', 'mag', 'dit', 'snr', 'max'.
    Each column is accessed via `.data`, as with an astropy Table.
    Magnitudes outside a setting's tabulated range evaluate to NaN.
    '''
    spxw_col = etc['spxw'].data
    spgw_col = etc['spgw'].data
    mag_col = etc['mag'].data

    dit, snr, max_val = {}, {}, {}
    for spxw in ('25mas', '100mas', '250mas'):
        dit[spxw], snr[spxw], max_val[spxw] = {}, {}, {}
        for spgw in ('J_low', 'J_short', 'J_middle', 'J_long',
                     'H_low', 'H_short', 'H_middle', 'H_long',
                     'K_low', 'K_short', 'K_middle', 'K_long'):

            indx = np.where((spxw_col == spxw) & (spgw_col == spgw))[0]
            mag = mag_col[indx]

            for store, column in ((dit, 'dit'), (snr, 'snr'), (max_val, 'max')):
                # bounds_error=False is required for out-of-range points to get
                # fill_value. With bounds_error=None and a non-'extrapolate'
                # fill_value, interp1d raises ValueError on them instead.
                store[spxw][spgw] = interp.interp1d(mag, etc[column].data[indx], kind='linear',
                                                    bounds_error=False, fill_value=np.nan)
    return dit, snr, max_val


def load_etc_prediction(etc_ao='etc-ao.txt', etc_noao='etc-noao.txt'):
    '''Load ETC predictions and build magnitude interpolators for AO and noAO modes.

    etc_ao, etc_noao - path to an ETC table (read with astropy `ascii.read`),
                       or an already-loaded table with columns
                       'spxw', 'spgw', 'mag', 'dit', 'snr', 'max'.

    Returns (dit_ao, snr_ao, max_ao, dit_noao, snr_noao, max_noao). Each is a
    nested dict indexed as [spxw][spgw] whose values are callables of
    magnitude. The callables return NaN outside the tabulated magnitude range.
    '''
    if isinstance(etc_ao, (str, os.PathLike)):
        etc_ao = ascii.read(etc_ao)
    if isinstance(etc_noao, (str, os.PathLike)):
        etc_noao = ascii.read(etc_noao)

    # Exposure time for different settings - AO
    dit_ao, snr_ao, max_ao = _etc_interpolators(etc_ao)

    # Exposure time for different settings - noAO
    dit_noao, snr_noao, max_noao = _etc_interpolators(etc_noao)

    return dit_ao, snr_ao, max_ao, dit_noao, snr_noao, max_noao


def pick_telluric(ra, dec, tel_ra, tel_dec, ob_duration,
                  after=True, max_am_OB=2.5, max_am_tel=2.5,
                  method='avg', lat=None):
    '''
    Rank candidate telluric standards by how closely their airmass matches a
    science target's, over a grid of possible OB start hour angles.

    ra - ra [radians] of target
    dec - dec [radians] of target
    tel_ra - ra [radians] of telluric stars (array)
    tel_dec - dec [radians] of telluric stars (array)
    ob_duration - the duration of the OB in minutes
    after - boolean, if true the telluric will be taken after the OB,
            otherwise before it
    max_am_OB - airmass constraint used in OB
    max_am_tel - maximum airmass to consider for the telluric
    method - 'avg' - minimum average airmass difference,
             'min' - minimum minimum airmass difference,
             'max' - minimum maximium airmass difference,
             'std' - minimum standard deviation of airmass difference.
    lat - observatory latitude [radians]. Defaults to Paranal, looked up
          with astropy `EarthLocation.of_site`.

    Returns a dict with:
      'tel_ind'    - indices into tel_ra/tel_dec ordered by the chosen
                     statistic, best telluric first
      'delta_am'   - (n_ha, n_tel) |telluric airmass - mean target airmass
                     over the OB|. NaN where the target is not observable at
                     that start HA. inf where the telluric is below its
                     limit, and for whole columns of rejected tellurics.
      'min_delta_am', 'max_delta_am', 'avg_delta_am', 'std_delta_am'
                   - per-telluric statistics of delta_am (inf for rejected)
      'ha_start'   - grid of OB start hour angles [radians], -8 h to +8 h
                     in 5 min steps
      'visibility' - per telluric, the number of grid points at which both
                     target and telluric satisfy their airmass limits
                     (counted before the rejection cuts)

    Raises ValueError if `method` is not one of 'min', 'max', 'avg', 'std'.
    '''
    valid_methods = ('min', 'max', 'avg', 'std')
    if method not in valid_methods:
        raise ValueError('method must be one of {}, got {!r}'.format(valid_methods, method))

    dr = np.pi/180.0

    if lat is None:
        paranal = c.EarthLocation.of_site('paranal')
        lat = paranal.lat.radian
    sin_lat = math.sin(lat)
    cos_lat = math.cos(lat)

    # if after == True: Time between end of science OB and middle of standard OB
    # if after == False: Time between middle of standard and start of science OB
    dt = (5.0/60 * 15.0 * dr)  # 5 minutes

    n_tel = len(tel_ra)

    # Minimum altitude for target and telluric (airmass ~ 1/sin(alt))
    target_min_alt = (np.pi/2.0) - np.arccos(1.0/max_am_OB)
    telluric_min_alt = (np.pi/2.0) - np.arccos(1.0/max_am_tel)

    # Grid of hour angles to attempt
    ha_start = np.arange(-8.0, 8.01, 5.0/60.0) * 15.0 * dr
    ha_end = ha_start + ((ob_duration/60.0) * 15.0 * dr)
    n = len(ha_start)

    # Altitude of target at start and end of OB
    alt_start = np.arcsin(sin_lat*math.sin(dec) + cos_lat*math.cos(dec)*np.cos(ha_start))
    alt_end = np.arcsin(sin_lat*math.sin(dec) + cos_lat*math.cos(dec)*np.cos(ha_end))

    # Array of delta airmass between target and telluric
    delta_am = np.full((n, n_tel), np.nan)
    target_visibility = 0
    n_has = 50  # samples used to average the target airmass over the OB

    # Loop over each HA attempted
    for i in range(0, n):
        if not ((alt_start[i] > target_min_alt) and (alt_end[i] > target_min_alt)):
            continue

        # Average airmass over OB
        target_has = np.linspace(ha_start[i], ha_end[i], n_has)
        target_alts = np.arcsin(sin_lat*math.sin(dec) + cos_lat*math.cos(dec)*np.cos(target_has))
        target_ams = 1.0/np.cos((np.pi/2.0) - target_alts)
        target_avg_am = np.mean(target_ams)

        # LST for tellurics
        if after:
            lst_telluric = ha_end[i] + ra + dt
        else:
            lst_telluric = ha_start[i] + ra - dt

        # HA for tellurics: HA = LST - RA
        tel_ha = lst_telluric - tel_ra

        # Altitude, airmass for tellurics
        tel_alt = np.arcsin(sin_lat*np.sin(tel_dec) +
                            cos_lat*np.cos(tel_dec)*np.cos(tel_ha))
        tel_am = 1.0/np.cos((np.pi/2.0) - tel_alt)

        # Flag those with altitudes below the minimum allowed
        # +inf here as we don't want the telluric to be considered at all
        tel_am[np.where(tel_alt <= telluric_min_alt)] = np.inf

        # Save the absolute difference in the airmass values
        delta_am[i] = np.abs(tel_am - target_avg_am)

        # Add one to the target visibility
        target_visibility += 1

    # Exclude tellurics that are visible for less than 90% of the time the target is
    visibility = np.sum(np.isfinite(delta_am), axis=0)
    delta_am[:, np.where(visibility < (0.9*target_visibility))[0]] = np.inf

    # Exclude tellurics with large differences in declination
    delta_am[:, np.where(np.abs(tel_dec - dec) > (50.0*dr))[0]] = np.inf

    # Per-telluric statistics of the airmass difference; NaNs (target not observable) are ignored
    std_delta_am = np.full(n_tel, np.inf)
    if not np.all(np.isinf(delta_am) | np.isnan(delta_am)):
        min_delta_am = np.nanmin(delta_am, axis=0)
        max_delta_am = np.nanmax(delta_am, axis=0)
        avg_delta_am = np.nanmean(delta_am, axis=0)
        for j in range(0, n_tel):
            std_indx = np.where(np.isfinite(delta_am[:, j]))[0]
            if len(std_indx) > 1:
                std_delta_am[j] = np.nanstd(delta_am[:, j][std_indx])
    else:
        min_delta_am, max_delta_am, avg_delta_am, std_delta_am = np.full((4, n_tel), np.inf)

    # Indices in order of increasing chosen statistic. Best tellurics at the start
    ranking = {'min': min_delta_am, 'max': max_delta_am,
               'avg': avg_delta_am, 'std': std_delta_am}
    tel_ind = np.argsort(ranking[method])

    return {'tel_ind': tel_ind,
            'delta_am': delta_am,
            'min_delta_am': min_delta_am,
            'max_delta_am': max_delta_am,
            'avg_delta_am': avg_delta_am,
            'std_delta_am': std_delta_am,
            'ha_start': ha_start,
            'visibility': visibility}

if __name__ == '__main__':
    # Script entry point: run the telluric selection when executed directly.
    # main()'s return value is not used as the process exit status.
    main()
