import ast
import datetime
import os
from pathlib import Path as P
from pathlib import PurePath as PP

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Rectangle
from readDataFile import read

# Much of the SQL was inspired by the class notes (https://www.mat.ucsb.edu/~g.legrady/academic/courses/15w259/d/SQL_demos_.pdf), section 'Volume and Multi-dimensions'.
# The lockout dates come from https://abc17news.com/news/2021/12/04/pro-sports-lockouts-and-strikes-fast-facts-2/


_DEFAULT_CSV = '/Volumes/GoogleDrive/My Drive/Courses/3rd year/MAT259/proj1/sports data.csv'


def _month_index(iso_date, base_year=2006):
    """Return the 0-based month offset of *iso_date* from January of *base_year*.

    *iso_date* is a ``YYYY-MM-DD`` string; the day component is ignored,
    e.g. ``_month_index('2007-03-15') == 14``.
    """
    d = datetime.date.fromisoformat(iso_date)
    return (d.year - base_year) * 12 + d.month - 1


def main(fname=_DEFAULT_CSV):
    """Plot monthly checkouts of sports media, one stacked subplot per sport.

    Reads the CSV at *fname*, which must contain one column per sport
    ('football', 'basketball', 'hockey', 'baseball') plus a 'year(cout)'
    column. Each sport's lockout period is shaded green, and every other
    year is shaded light gray. The figure is saved as ``output.png`` in
    the same directory as the CSV.

    Row ``k`` of the CSV is treated as month ``k`` counted from January
    2006 (see ``_month_index``). NOTE(review): this assumes one row per
    month starting in Jan 2006. Confirm against the SQL export.

    Args:
        fname: Path to the input CSV. Defaults to the original project path.
    """
    data = pd.read_csv(fname)
    sports = ['football', 'basketball', 'hockey', 'baseball']
    fig, ax = plt.subplots(len(sports), sharex=True)
    # Lockout (start, end) dates as ISO strings; only year and month are used.
    lockouts = {
        'hockey': ['2012-09-15', '2013-01-12'],
        'basketball': ['2011-07-01', '2011-12-08'],
        'football': ['2011-03-12', '2011-08-04'],
        'baseball': ['2021-12-02', '2022-01-01'],
    }
    # Unique years in chronological order. Values are cast with int() because
    # the column's dtype is not guaranteed numeric.
    years = sorted(set(data['year(cout)']), key=int)

    for i, s in enumerate(sports):
        ax[i].plot(data[s], label=s.title())

        span = lockouts.get(s)
        if span is not None:
            start_idx = _month_index(span[0])
            end_idx = _month_index(span[1])
            # Spans of exactly one month are widened by one month at the start
            # (presumably so short lockouts are easier to see).
            if end_idx - start_idx == 1:
                start_idx -= 1
            ax[i].axvspan(start_idx, end_idx, alpha=0.5, facecolor='green')

        ax[i].spines['top'].set_visible(False)
        ax[i].spines['right'].set_visible(False)
        # Sport name drawn vertically at x=193, a hard-coded position that is
        # presumably just past the last data row. TODO: derive from len(data).
        ax[i].text(193, 0.5 * np.max(data[s]), s.title(), rotation=270, va='center')

    # Figure-wide decorations only need to be applied once.
    fig.supxlabel('Year')
    fig.supylabel('Checkouts (books and dvds)')
    fig.suptitle('Monthly checkouts of sports media\n(lockouts highlighted in green)')
    # The x-axis is shared, so one tick every 12 rows (one per year) is set on
    # the bottom axes and applies to all of them.
    ax[-1].set_xticks(range(0, len(data['year(cout)']), 12))
    # NOTE(review): labels are year + 1, so the tick at row 0 reads one year
    # after the first year in the data. Confirm this offset is intended.
    tick_labels = [int(y) + 1 for y in years]
    for a in ax:
        a.set_xticklabels(tick_labels, rotation=70)

    # Shade every other year (odd positions in `years`) across all subplots.
    for i in range(1, len(years), 2):
        for a in ax:
            a.axvspan(12 * i, 12 * (i + 1), facecolor='gray', alpha=0.2)

    plt.tight_layout()

    plt.savefig(P(fname).parent.joinpath('output.png'), dpi=300)

if __name__ == "__main__":
    # Build the figure and save it (main() writes output.png next to the CSV),
    # then display it in an interactive window.
    main()
    plt.show()
