import json
from pathlib import Path
import numpy as np
import re
from . import RawFileReader
from .enums import DetectorType

class RawFile:
    """
    Generic raw file reader. Picks up its settings from a master file.

    The master file is either a ``.json`` file or a text ``.raw`` master
    file. Currently supported combinations are 'Moench' and 'ChipTestBoard'
    detectors with 5000 analog samples; anything else raises ValueError.

    Instances can be iterated over (one ``reader.read()`` result per step)
    and used in a ``with`` statement.
    """

    # Fields of a text master file holding an integer, optionally followed by
    # a unit (only the first whitespace separated token is converted).
    _INT_FIELDS = frozenset((
        "Analog Samples",
        "Analog Flag",
        "Digital Flag",
        "Digital Samples",
        "Max Frames Per File",
        "Image Size",
        "Frame Padding",
        "Total Frames",
        "Dynamic Range",
        "Ten Giga",
        "Quad",
        "Number of Lines read out",
        "Number of UDP Interfaces",
    ))

    # Fields holding a time with unit, converted with to_nanoseconds()
    _TIME_FIELDS = frozenset((
        "Exptime",
        "Exptime1",
        "Exptime2",
        "Exptime3",
        "GateDelay1",
        "GateDelay2",
        "GateDelay3",
        "SubExptime",  # Eiger
        "SubPeriod",  # Eiger
        "Period",
    ))

    def __init__(self, fname, header=False):
        """
        Load the master file and open the first data file.

        fname: path (str or path-like) to the .json or .raw master file
        header: forwarded unchanged to RawFileReader

        Raises ValueError for an unknown suffix, an unparsable master file
        name or an unsupported detector configuration.
        """
        self.findex = 0
        # Store a Path: _parse_fname needs .stem and .parent, so keeping a
        # plain str here made every str argument fail later on.
        self.fname = Path(fname)
        self.json = False  # Master file data comes from json

        suffix = self.fname.suffix
        if suffix == '.json':
            # NOTE(review): _parse_values() is only run for text master files,
            # so its json branch is never reached from here. Confirm whether
            # json master files should be normalised as well.
            with open(self.fname) as f:
                self.master = json.load(f)
            self.json = True
        elif suffix == '.raw':
            with open(self.fname) as f:
                self._load_raw_master(f)
        else:
            raise ValueError(f'{suffix} not supported')

        # Figure out which file to open. Missing fields end up in the
        # 'unsupported file' branch instead of raising KeyError.
        det_type = self.master.get('Detector Type')
        analog_samples = self.master.get('Analog Samples')

        if det_type == 'Moench' and analog_samples == 5000:
            #TODO! pass settings to reader
            self._parse_fname()
            self.reader = RawFileReader(
                self.data_fname(0, 0), DetectorType.MOENCH_03, header=header
            )
        elif det_type == 'ChipTestBoard' and analog_samples == 5000:
            self._parse_fname()

            # Do we read analog or analog+digital
            if self.master['Digital Flag'] == 1:
                dt = DetectorType.MOENCH_04_AD
            else:
                # NOTE(review): same type as the branch above, which looks
                # like a copy-paste slip. An analog-only detector type is
                # presumably intended here; confirm the enum member name.
                dt = DetectorType.MOENCH_04_AD
            self.reader = RawFileReader(self.data_fname(0, 0), dt, header=header)
        else:
            raise ValueError('unsupported file')

    def _parse_fname(self):
        """
        Split the master file name into self.base and self.run_id.

        The stem is split from the right on '_' into three parts; the first
        part (joined with the parent directory) becomes self.base, the middle
        part is ignored and the last part is converted to int as self.run_id.

        Raises ValueError if the name does not have that layout.
        """
        try:
            base, _, run_id = self.fname.stem.rsplit("_", 2)
            self.base = self.fname.parent / base
            self.run_id = int(run_id)
        except (AttributeError, ValueError) as err:
            raise ValueError(
                f"Could not parse master file name: {self.fname}"
            ) from err

    @staticmethod
    def _split_field(line):
        """Split 'field : value' at the first colon, ValueError if there is none."""
        field, sep, value = line.partition(":")
        if not sep:
            raise ValueError(f"Malformed master file line: {line!r}")
        return field, value

    def _load_raw_master(self, f):
        """
        Parse a text master file from the open file object f into self.master.

        Lines up to the one starting with '#Frame' are stored as
        'field: value' entries; the remaining lines are collected in
        self.master["Frame Header"]. Blank lines are skipped. "Version" is
        converted to float and the other values are converted by
        _parse_values().

        Raises ValueError on a line without ':' or an unterminated
        'Scan Parameters' entry.
        """
        self.master = {}
        it = iter(f.readlines())

        for line in it:
            if line.startswith("#Frame"):
                break
            if not line.strip():
                continue
            if line.startswith("Scan Parameters"):
                # The value can span several lines, keep appending until the
                # closing bracket shows up.
                while not line.rstrip().endswith("]"):
                    continuation = next(it, None)
                    if continuation is None:
                        raise ValueError(
                            "Unterminated 'Scan Parameters' entry in master file"
                        )
                    line += continuation

            field, value = self._split_field(line)
            self.master[field.strip(" ")] = value.strip(" \n")

        frame_header = {}
        for line in it:
            # A trailing empty line must not abort the parsing
            if not line.strip():
                continue
            field, value = self._split_field(line)
            frame_header[field.strip()] = value.strip(" \n")

        self.master["Frame Header"] = frame_header
        self.master["Version"] = float(self.master["Version"])
        self._parse_values()

    def _parse_values(self):
        """
        Convert the raw values in self.master to Python/numpy types in place.

        Time fields become np.timedelta64, "Pixels" becomes a tuple and
        "Rate Corrections" / "Threshold Energies" become lists. Integer
        fields are only converted for text master files (self.json False).

        Raises ValueError if the number of rate corrections does not match
        a known "nmod".
        """
        master = self.master

        # Some fields might not exist for all detectors
        # hence using intersection
        for field in self._TIME_FIELDS.intersection(master):
            master[field] = self.to_nanoseconds(master[field])

        # Parse both .json and .raw master files
        if self.json:
            master['Image Size'] = master["Image Size in bytes"]
            master['Pixels'] = (master['Pixels']['x'], master['Pixels']['y'])
            # ports not modules
            master['nmod'] = int(master['Geometry']['x'] * master['Geometry']['y'])
            if master['Detector Type'] == 'Eiger':
                master['nmod'] = master['nmod'] // 2
        else:
            for field in self._INT_FIELDS.intersection(master):
                master[field] = int(master[field].split()[0])
            master["Pixels"] = tuple(
                int(i) for i in master["Pixels"].strip("[]").split(",")
            )

        if "Rate Corrections" in master:
            master["Rate Corrections"] = (
                master["Rate Corrections"].strip("[]").split(",")
            )
            n = len(master["Rate Corrections"])
            # "nmod" is only set for json master files, so the consistency
            # check is skipped when it is not available. An explicit raise
            # is used since assert statements are removed when running with -O.
            if "nmod" in master and master["nmod"] != n:
                raise ValueError(
                    f'nmod from Rate Corrections {n} differs from nmod {master["nmod"]}'
                )

        # Parse threshold for Mythen3 (if needed)
        if "Threshold Energies" in master:
            th = master["Threshold Energies"]
            if isinstance(th, str):
                master["Threshold Energies"] = [
                    int(i) for i in th.strip('[]').split(',')
                ]

    @staticmethod
    def to_nanoseconds(t):
        """
        Convert a string like "10us" or "1.5ms" to np.timedelta64 in ns.

        Supported units are s, ms, us and ns. The result is rounded to the
        nearest nanosecond.

        Raises ValueError if t cannot be converted.
        """
        nanoseconds = {"s": 1000 * 1000 * 1000, "ms": 1000 * 1000, "us": 1000, "ns": 1}
        try:
            value = re.match(r"(\d+(?:\.\d+)?)", t).group()
            unit = t[len(value):]
            # round instead of truncating: the float product can end up just
            # below the intended integer (1.001 * 1000 -> 1000.9999999999999)
            ns = round(float(value) * nanoseconds[unit])
            return np.timedelta64(ns, 'ns')
        except (AttributeError, KeyError, TypeError, ValueError, OverflowError) as err:
            raise ValueError(f"Could not convert: {t} to nanoseconds") from err

    def data_fname(self, i, findex=0):
        """
        Return the Path "<base>_d<i>_f<findex>_<run_id>.raw".

        Uses self.base and self.run_id as set by _parse_fname().
        """
        return Path(f"{self.base}_d{i}_f{findex}_{self.run_id}.raw")

    def read(self, *args):
        """Forward all arguments to reader.read() and return its result."""
        return self.reader.read(*args)

    # Support iteration
    def __iter__(self):
        return self

    def __next__(self):
        """
        Return the next result of reader.read().

        Iteration stops when the returned array has length 0 along its
        first axis.
        """
        res = self.reader.read()
        if res.shape[0] == 0:
            raise StopIteration
        return res

    # Support with statement
    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Nothing is released here; returning None lets exceptions propagate
        pass