#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
import os
import sys

def read_data(filename: str, col_ind: int) -> np.ndarray:
    """Load one column of numeric data from a text file.

    The file is parsed with ``numpy.loadtxt`` using its default
    (whitespace) delimiter.

    Args:
        filename (str): path of the text file to load
        col_ind (int): zero-based index of the column to extract

    Returns:
        numpy.ndarray: the values in column ``col_ind`` of ``filename``

    Raises:
        IOError: if ``filename`` does not exist

    """
    file_missing = not os.path.exists(filename)
    if file_missing:
        raise IOError("Requested file does not exist.")

    # let the user know which file is being loaded
    print("reading data from file", filename)
    return np.loadtxt(filename, usecols=col_ind, unpack=True)


def _plot_multipole_timeseries(file: str, spin: str, rex: float, mp: str, yscale: str, savepath: str,
                               field: str, rex_power: int) -> str:
    """Plot one multipole timeseries scaled by ``rex**rex_power`` and save it as SVG.

    Shared implementation behind ``plot_phi_lm`` and ``plot_theta_lm``.

    Args:
        file (str): file with the multipole data (column 0: time, column 1: multipole value)
        spin (str): BH spin label, appended to the legend text
        rex (float): extraction radius
        mp (str): l, m of the multipole, formatted as "lm"
        yscale (str): 'log' plots |value| on a log y-axis; any other value gives a linear axis
        savepath (str): directory to save the figure in ('' means the current directory)
        field (str): LaTeX symbol of the plotted field, e.g. r'\\Phi'
        rex_power (int): power of ``rex`` the data is multiplied by

    Returns:
        str: path of the saved SVG file

    Raises:
        IOError: if ``file`` does not exist (raised by ``read_data``)

    """
    # Read the data before creating a figure so a missing file leaks nothing.
    times = read_data(file, 0)
    values = read_data(file, 1)

    # Name the figure after the data file, dropping its directory and extension.
    stem = os.path.splitext(os.path.basename(file))[0]
    savename = os.path.join(savepath, f"{stem}.svg")

    # Shift times by the extraction radius (the x-axis is (t - r_ex)/M) and
    # rescale the data by r_ex**rex_power.
    shifted_times = times - rex
    scaled = values * rex ** rex_power
    rex_factor = r'r_{ex}' if rex_power == 1 else fr'r_{{ex}}^{rex_power}'
    label = r'$\chi = $' + spin

    font = {'size': 14}
    rc('font', **font)
    fig, axs = plt.subplots(1, 1)
    try:
        fig.set_figheight(5)
        fig.set_figwidth(6)

        if yscale == 'log':
            axs.semilogy(shifted_times, np.abs(scaled), label=label)
            ylabel = fr'${rex_factor}|{field}_{{{mp}}}|$'
        else:
            axs.plot(shifted_times, scaled, label=label)
            axs.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
            ylabel = fr'${rex_factor}{field}_{{{mp}}}$'

        axs.set_xlim(0, 200)
        axs.set_xticks([0, 50, 100, 150, 200])
        axs.legend(loc='upper right')
        axs.set_xlabel(r'$(t-r_{ex})/M$', size=18)
        axs.set_ylabel(ylabel, size=18)
        fig.subplots_adjust(left=0.18, right=0.95, bottom=0.15, top=0.95)
        axs.grid(True, linestyle='--')
        fig.savefig(savename)
    finally:
        # Close the figure; otherwise every call leaves an open figure behind.
        plt.close(fig)
    return savename


def plot_phi_lm(file: str, spin: str, rex: float, mp: str, yscale: str, savepath: str) -> str:
    """ Plot requested Phi multipole timeseries

    Read the data from the requested file and plot r_ex * Phi_lm using either linear or logscale. Save the
    plot to the directory specified in savepath.

    Args:
        file (str): file with the multipole data
        spin (str): BH spin from the simulation you are considering; this should be a string as it gives the label
        rex (float): extraction radius, i.e. where on the grid you are extracting your data from
        mp (str): l, m for the multipole you want to plot, formatted as "lm"
        yscale (str): scale of yaxis; options are {'log', 'lin'} (anything other than 'log' is linear)
        savepath (str): directory to save the figure in

    Returns:
        str: path of the saved SVG file

    """
    return _plot_multipole_timeseries(file, spin, rex, mp, yscale, savepath, field=r'\Phi', rex_power=1)


def plot_theta_lm(file: str, spin: str, rex: float, mp: str, yscale: str, savepath: str) -> str:
    """ Plot requested Theta multipole timeseries

    Read the data from the requested file and plot r_ex^2 * Theta_lm using either linear or logscale. Save the
    plot to the directory specified in savepath.

    Args:
        file (str): file with the multipole data
        spin (str): BH spin from the simulation you are considering; this should be a string as it gives the label
        rex (float): extraction radius, i.e. where on the grid you are extracting your data from
        mp (str): l, m for the multipole you want to plot, formatted as "lm"
        yscale (str): scale of yaxis; options are {'log', 'lin'} (anything other than 'log' is linear)
        savepath (str): directory to save the figure in

    Returns:
        str: path of the saved SVG file

    """
    return _plot_multipole_timeseries(file, spin, rex, mp, yscale, savepath, field=r'\Theta', rex_power=2)

####################################################################
# Example to use plotting script
####################################################################

# Fill in the paths to the multipole output files before running this script.
aSpt01_apt9_Theta10_r20 = None  # path to mp_Theta_l1_m0_r20.00.asc
aSpt01_apt9_Phi00_r20 = None  # path to mp_Phi_l0_m0_r20.00.asc
aSpt01_apt9_Theta30_r20 = None  # path to mp_Theta_l3_m0_r20.00.asc
aSpt01_apt9_Phi20_r20 = None  # path to mp_Phi_l2_m0_r20.00.asc


def main() -> int:
    """Plot the example multipoles and return a process exit status.

    Returns:
        int: 1 if any data path above is still unset (nothing is plotted), otherwise 0

    """
    data_files = (
        aSpt01_apt9_Theta10_r20,
        aSpt01_apt9_Phi00_r20,
        aSpt01_apt9_Theta30_r20,
        aSpt01_apt9_Phi20_r20,
    )
    if any(path is None for path in data_files):
        print("Please set paths to the data files in the source code", file=sys.stderr)
        return 1

    # Spin label 0.9, extraction radius 20.0; savepath '' writes the SVGs to the current directory.
    plot_theta_lm(aSpt01_apt9_Theta10_r20, r'$0.9$', 20.0, "10", 'lin', '')
    plot_phi_lm(aSpt01_apt9_Phi00_r20, r'$0.9$', 20.0, "00", 'lin', '')
    plot_theta_lm(aSpt01_apt9_Theta30_r20, r'$0.9$', 20.0, "30", 'log', '')
    plot_phi_lm(aSpt01_apt9_Phi20_r20, r'$0.9$', 20.0, "20", 'log', '')
    return 0


# Run the example only when executed as a script, so importing this module
# (e.g. to reuse read_data or the plotting functions) has no side effects.
if __name__ == "__main__":
    sys.exit(main())
