#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import copy
import enum
import functools
import hashlib
import json
import os
import platform
import subprocess
import sys
from datetime import datetime
from shutil import which

try:
    from lxml import etree
    import osmium
    import requests
    from osmium.osm import TagList, OSMObject
    from tqdm import tqdm
except ImportError:
    raise ImportError("You are missing some Python modules. To install libs, run:\n\tpip install -r requirements.txt\n")

# Directory of the configuration files (Geofabrik cookie, cached Geofabrik index, tag filter files).
config_dir = "config/"
# Prefix of all data files; Downloader extends it with "<country name>/".
data_dir = ""

# object id ("n1", "w2", "r3") -> set of affected tag keys; definitely unfixable
interesting_unfixable_objects = {}

# object id -> set of affected tag keys; probably auto fixable, but unsure
interesting_fixable_objects = {}

# object id -> {tag key: restored value}, filled from the history file
saved_states = {}

# Prefixes that mark a damaged key/value. Presumably the first characters of the
# localized "<different>" placeholder JOSM shows when multi-selected objects have
# differing values (e.g. "<di" for English) — TODO confirm the list is complete.
translations = (
    "<ро", "<ра", "<di", "<rů", "<fo", "<un", "<δι", "<er", "<kü", "<pe", "<mi", "<差異", "<ខុ", "<차이", "<įv", "<वे",
    "<ym", "<ve", "<ró", "<rô", "<ol", "<рі", "<kh", "<不同",
)


def py_version_check():
    """Print the running interpreter version and refuse to run on Python < 3.7.

    :raises RuntimeError: if the interpreter is older than Python 3.7
    """
    major, minor, micro = sys.version_info[:3]
    print(f"You are running Python {major}.{minor}.{micro}")
    if sys.version_info < (3, 7):
        raise RuntimeError("Python < 3.7 is not supported, please upgrade and try again")


@enum.unique
class MsgLevel(enum.Enum):
    """Formatting level of a message printed by ``Print.status``.

    No trailing commas after ``auto()``: those would turn each value into a
    1-tuple, and before Python 3.11.1 ``auto()`` inside a tuple is not resolved.
    """

    INFO = enum.auto()        # plain "# msg" line
    MAJOR = enum.auto()       # blue "## Msg" heading
    MAJOR_STEP = enum.auto()  # separator plus numbered "### MSG (step N)" banner


class Print:
    """Console status printer with ANSI colours and a running step counter."""

    class bcolors:
        # ANSI terminal escape sequences
        HEADER = '\033[95m'
        OKBLUE = '\033[94m'
        OKCYAN = '\033[96m'
        OKGREEN = '\033[92m'
        WARNING = '\033[93m'
        FAIL = '\033[91m'
        ENDC = '\033[0m'
        BOLD = '\033[1m'
        UNDERLINE = '\033[4m'

    # number of MAJOR_STEP banners printed so far
    counter = 0

    @classmethod
    def status(cls, msg: str, msg_level: MsgLevel = MsgLevel.INFO):
        """Print *msg* formatted according to *msg_level*.

        - INFO: ``# msg``
        - MAJOR: blue ``## Msg`` with only the first character upper-cased
        - any other level (MAJOR_STEP): an 80-character separator line and a bold
          ``### MSG (step N)`` banner; increments the step counter
        """
        if msg_level == MsgLevel.INFO:
            print(f"\n# {msg}")
        elif msg_level == MsgLevel.MAJOR:
            # Upper-case only the first character: str.capitalize() would also
            # lower-case the rest and mangle acronyms and paths ("JOSM" -> "josm").
            print(f"\n{cls.bcolors.OKBLUE}## {msg[:1].upper()}{msg[1:]}{cls.bcolors.ENDC}")
        else:
            cls.counter += 1
            print(f"{'-' * 80}"
                  f"\n{cls.bcolors.HEADER}{cls.bcolors.BOLD}### {msg.upper()} (step {cls.counter}){cls.bcolors.ENDC}")


# https://stackoverflow.com/a/65655346
def major_step(msg):
    """Decorator factory: announce *msg* as a numbered major step around the call.

    The wrapped callable prints a MAJOR_STEP banner before running and an "OK"
    INFO line after it returns; its return value is passed through unchanged.

    :param msg: step description shown in the banner
    """
    def decorator_function(original_function):
        # wraps() keeps __name__, __doc__ and __wrapped__ of the decorated callable
        @functools.wraps(original_function)
        def wrapper_function(*args, **kwargs):
            Print.status(msg, MsgLevel.MAJOR_STEP)
            result = original_function(*args, **kwargs)
            Print.status("OK", MsgLevel.INFO)
            return result

        return wrapper_function

    return decorator_function


class Downloader:
    """Fetch and verify the latest and the full-history Geofabrik extract of one country.

    All work happens in the constructor:
    - the Geofabrik index is searched for the given ISO 3166-1 alpha-2 code;
    - the files are downloaded into ``data_dir + <country name> + "/"`` (the global
      ``data_dir`` is updated to that directory);
    - each file is checked against the MD5 sum published next to it.
    """

    @major_step("Download latest data files...")
    def __init__(self, iso3166: str, update: bool):
        """
        :param iso3166: ISO 3166-1 alpha-2 code looked up in the Geofabrik index (e.g. "HU")
        :param update: True to download the files again even if they already exist locally
        """
        self._iso3166 = iso3166
        self._update = update
        self._country = None

        # filenames
        self._latest = None
        self._history = None

        # JSON index of all downloads that Geofabrik offers, see https://download.geofabrik.de/technical.html
        self._geofabrik_index = "https://download.geofabrik.de/index-v1-nogeom.json"

        self.download_latest_pbf()

    @property
    def latest(self) -> str:
        """Path of the downloaded latest-data PBF file."""
        return self._latest

    @property
    def history(self) -> str:
        """Path of the downloaded history file (the one whose name contains ".osh")."""
        return self._history

    @property
    def country_name(self):
        """Country name as given in the Geofabrik index."""
        return self._country

    @property
    def country_code(self):
        """The ISO 3166-1 alpha-2 code this downloader was created with."""
        return self._iso3166

    @staticmethod
    def get_cookie():
        """Load the Geofabrik auth cookie that is sent with every download request.

        :raises FileNotFoundError: with setup instructions if the cookie file is missing
        """
        filename = config_dir + "geofabrik_cookie.json"
        try:
            with open(filename, "r") as f:
                return json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(
                f"Create a {filename} with the Geofabrik auth cookie. "
                f"The file should only contain the gf_download_oauth cookie in JSON object format:\n"
                f"{{\n"
                f"  \"gf_download_oauth\": \"login|2018-04-12|...\"\n"
                f"}}")

    @staticmethod
    def _modification_date(path_to_file):
        """Return the last modification time of *path_to_file* as a POSIX timestamp."""
        # os.path.getmtime is os.stat().st_mtime on every platform
        return os.path.getmtime(path_to_file)

    def _get_download_urls(self):
        """Return the ``{"normal": url, "history": url}`` PBF download URLs of the country.

        The Geofabrik index is cached as ``config_dir/geofabrik_paths.json`` and
        refreshed when it is more than 7 days old. Also sets the country name.

        :raises RuntimeError: if no index entry lists the configured country code
        """
        gf_paths = config_dir + "geofabrik_paths.json"

        # cache the index in the filesystem
        if not os.path.isfile(gf_paths) or (
                datetime.now() - datetime.fromtimestamp(self._modification_date(gf_paths))).days > 7:
            geofabrik_index = requests.get(self._geofabrik_index)
            geofabrik_index.raise_for_status()
            # store the raw bytes so the platform default encoding (e.g. cp1252 on
            # Windows) cannot fail on non-ASCII names in the index
            with open(gf_paths, "wb") as f:
                f.write(geofabrik_index.content)

        with open(gf_paths, "r", encoding="utf-8") as f:
            data = json.load(f)

        for entry in data["features"]:
            properties = entry.get("properties", {})
            # expected to be a list of codes; entries without it are skipped
            codes = properties.get("iso3166-1:alpha2")
            if codes and self.country_code in codes:
                self._country = properties["name"]
                return {"normal": properties["urls"]["pbf"],
                        "history": properties["urls"]["history"]}

        raise RuntimeError("Give valid iso3166-1:alpha2 code!")

    def _verify_download(self, url, file):
        """Compare the MD5 of *file* with the one published at ``url + ".md5"``; exit(1) on mismatch.

        :raises requests.HTTPError: if the checksum cannot be downloaded
        """
        Print.status(f"Verifying {file}...")
        with requests.get(url + ".md5", cookies=self.get_cookie()) as r:
            r.raise_for_status()
            server_md5 = r.text.split()[0]

        # hash in 1 MiB chunks: the extracts may be large, never read them into memory at once
        digest = hashlib.md5()
        with open(file, "rb") as fh:
            for chunk in iter(lambda: fh.read(1024 * 1024), b""):
                digest.update(chunk)

        if server_md5 == digest.hexdigest():
            Print.status("File valid")
        else:
            Print.status("File INVALID, aborting.")
            sys.exit(1)

    def download_latest_pbf(self):
        """Download (if missing or when updating) and verify the latest and the history file.

        Extends the global ``data_dir`` with the country name, creates that
        directory, and records the local paths in ``latest`` / ``history``.
        """
        def _download(download_url, fname):
            # stream to disk with a progress bar; the response is closed afterwards
            with requests.get(download_url, cookies=self.get_cookie(), stream=True) as resp:
                resp.raise_for_status()

                total = int(resp.headers.get('content-length', 0))

                # https://stackoverflow.com/a/62113293
                with open(fname, 'wb') as file, tqdm(
                        desc=fname,
                        total=total,
                        unit='iB',
                        unit_scale=True,
                        unit_divisor=1024,
                ) as bar:
                    for data in resp.iter_content(chunk_size=64 * 1024):
                        size = file.write(data)
                        bar.update(size)

            Print.status(f"{fname} download finished!")

        urls = self._get_download_urls()

        global data_dir
        data_dir = data_dir + self.country_name + "/"
        os.makedirs(data_dir, exist_ok=True)

        for url in urls.values():
            filename = data_dir + url.rsplit('/', 1)[-1]

            if not os.path.isfile(filename) or self._update:
                _download(url, filename)

            self._verify_download(url, filename)

            # update filenames
            if ".osh" in filename:
                self._history = filename
            else:
                self._latest = filename

    @staticmethod
    def print_header(path):
        """Print the header information (bbox, history flag, replication data) of an OSM file."""
        # NOTHING: only the header is needed, no OSM entities are read
        reader = osmium.io.Reader(path, osmium.osm.osm_entity_bits.NOTHING)
        try:
            header = reader.header()

            Print.status(f"FILE INFO ({path})")
            print("Bbox:                       ", header.box())
            print("History file:               ", header.has_multiple_object_versions)
            print("Replication base URL:       ", header.get("osmosis_replication_base_url", "<none>"))
            print("Replication sequence number:", header.get("osmosis_replication_sequence_number", "<none>"))
            print("Replication timestamp:      ", header.get("osmosis_replication_timestamp", "<none>"))
            print("")
        finally:
            # release the underlying file handle
            reader.close()


class IdCollector(osmium.SimpleHandler):
    """Find objects whose tag keys or values start with one of the ``translations`` prefixes.

    Each such object is recorded as ``{object_id: set of affected keys}`` in the
    global ``interesting_fixable_objects`` or ``interesting_unfixable_objects``.
    """

    def __init__(self):
        super(IdCollector, self).__init__()
        # affected keys of the object inspected last
        self.key = set()
        Print.status("Gathering OSM object IDs", MsgLevel.MAJOR_STEP)

    def apply_file(self, filename, locations=False, idx='flex_mem'):
        """Run the handler on ``data_dir + filename``."""
        super().apply_file(data_dir + filename, locations, idx)

    def inspect_tags(self, osm_object: OSMObject, object_id: str):
        """Classify *osm_object* if it carries placeholder keys or values.

        - a placeholder *key* cannot be restored -> unfixable;
        - placeholder *values* on a version-1 object have no earlier value -> unfixable;
        - placeholder values on a later version -> probably fixable from history.

        All affected keys are collected, not just the first one.
        """
        # A new set per object: the dicts below keep a reference to it, so reusing
        # (and clearing) one shared set would make every entry alias the same keys.
        self.key = set()
        broken_key = False
        for tag in osm_object.tags:
            if tag.k.startswith("<") and tag.k.startswith(translations):
                broken_key = True
            elif tag.v.startswith("<") and tag.v.startswith(translations):
                self.key.add(tag.k)

        if not (broken_key or self.key):
            return

        if self.key:
            print(osm_object)

        if broken_key or osm_object.version <= 1:
            interesting_unfixable_objects[object_id] = self.key
        else:
            interesting_fixable_objects[object_id] = self.key

    def node(self, n):
        self.inspect_tags(n, "n" + str(n.id))

    def way(self, w):
        self.inspect_tags(w, "w" + str(w.id))

    def relation(self, r):
        self.inspect_tags(r, "r" + str(r.id))


class HistoryHandler(osmium.SimpleHandler):
    """Find, in the history file, the last good value of every affected tag of the fixable objects.

    Results go into the global ``saved_states`` as ``{object_id: {key: value}}``.
    """

    def __init__(self):
        super(HistoryHandler, self).__init__()
        # {key: last good value} of the object currently being read
        self.last_known_good = dict()
        self.prev_id = 0

    def apply_file(self, filename, locations=False, idx='flex_mem'):
        """Run the handler on ``data_dir + filename`` and print the collected states."""
        super().apply_file(data_dir + filename, locations, idx)
        for object_id, state in saved_states.items():
            print(f"Saved previous state for {object_id}: {state}")

    def inspect_tags(self, osm_object: OSMObject, object_id: str):
        """Remember the last non-placeholder value of every affected key of *osm_object*.

        Reading is linear and the objects are sorted, so all versions of an object
        arrive consecutively and oldest first; the value stored last is therefore
        the most recent good one.
        """
        keys = interesting_fixable_objects.get(object_id)
        if keys:
            if object_id != self.prev_id:
                # new object: start a fresh dict so saved_states entries never share one
                self.last_known_good = dict()
            for tag in osm_object.tags:
                if tag.k in keys and not (tag.v.startswith("<") and tag.v.startswith(translations)):
                    self.last_known_good[tag.k] = tag.v
            if self.last_known_good:
                saved_states[object_id] = self.last_known_good
        self.prev_id = object_id

    def node(self, n):
        self.inspect_tags(n, "n" + str(n.id))

    def way(self, w):
        self.inspect_tags(w, "w" + str(w.id))

    def relation(self, r):
        self.inspect_tags(r, "r" + str(r.id))


class RestoreHandler(osmium.SimpleHandler):
    """Copy every object of a file to *writer*, restoring the values saved in ``saved_states``."""

    @major_step("Restoring object values...")
    def __init__(self, writer):
        """
        :param writer: output sink providing add_node/add_way/add_relation
        """
        super(RestoreHandler, self).__init__()
        self.writer = writer

    def apply_file(self, filename, locations=False, idx='flex_mem'):
        """Run the handler on ``data_dir + filename``."""
        super().apply_file(data_dir + filename, locations, idx)

    @staticmethod
    def restore(osm_object: OSMObject, object_id: str):
        """Return *osm_object* with placeholder values replaced by their saved values.

        Objects without a saved state are returned unchanged. Restored objects get an
        extra ``josmbug=fixed`` marker tag. Tags that are not restored are kept as is.
        """
        state = saved_states.get(object_id)
        if state is None:
            return osm_object

        tags = []
        for tag in osm_object.tags:
            if tag.k in state and tag.v.startswith("<") and tag.v.startswith(translations):
                # put back the last known good value of this key only
                tags.append((tag.k, state[tag.k]))
            else:
                tags.append((tag.k, tag.v))
        tags.append(("josmbug", "fixed"))
        return osm_object.replace(tags=tags)

    def node(self, n):
        self.writer.add_node(self.restore(n, "n" + str(n.id)))

    def way(self, w):
        self.writer.add_way(self.restore(w, "w" + str(w.id)))

    def relation(self, r):
        self.writer.add_relation(self.restore(r, "r" + str(r.id)))


class OsmiumRunner:
    def __init__(self, overwrite=False):
        """
        Osmium command line tool runner (wrapper) class with predefined calls.
        :param overwrite: True if existing files should be overwritten
        :raises RuntimeError: if the ``osmium`` executable is not found on PATH
        """
        self._overwrite = overwrite
        self._osmium_tool_check()

    @property
    def overwrite(self):
        return self._overwrite

    def _overwrite_args(self):
        """Return the extra CLI arguments for the overwrite setting (empty list when disabled).

        Optional flags are only added when enabled: an empty-string placeholder would
        reach osmium as a bogus extra positional argument.
        """
        return ["--overwrite"] if self.overwrite else []

    @staticmethod
    def _osmium_tool_check():
        if not is_tool("osmium"):
            raise RuntimeError(
                "The osmium tool (https://osmcode.org/osmium-tool/) not installed, please install and try again")

    @staticmethod
    def show_file_info(file, extended: bool, crc: bool):
        """Show info about an OSM file.
        :param file: Input file
        :param extended: True if extended file information should be shown
        :param crc: True if CRC checksum should be calculated and printed
        """

        Print.status("READING FILE HEADER...", MsgLevel.INFO)

        cmd = ["osmium", "fileinfo", file]
        if extended:
            cmd.append("-e")
        if crc:
            cmd.append("-c")

        subprocess.run(cmd)

    @major_step("filtering primitives...")
    def filter_tags(self, filter_file, input_file, output_file):
        """Run ``osmium tags-filter`` with the expressions in ``config_dir + filter_file``.

        The result is written to ``data_dir + output_file``.
        """
        subprocess.run(
            ["osmium", "tags-filter", "--progress", "-e", config_dir + filter_file, input_file, "-o",
             data_dir + output_file] + self._overwrite_args())

    @major_step("GETTING OBJECTS FROM HISTORY FILE...")
    def object_by_id(self, input_file, output_file, ids):
        """Extract all versions of the given objects into ``data_dir + output_file``.

        :param ids: whitespace-separated string of IDs such as "n1 w2 r3", or an iterable of such IDs
        """
        # each ID must be its own argv element; a single space-joined argument
        # would reach osmium as one unparsable ID
        if isinstance(ids, str):
            ids = ids.split()
        subprocess.run(
            ["osmium", "getid", input_file, "-o", data_dir + output_file, *ids, "--with-history"]
            + self._overwrite_args())


@major_step("Adding JOSM specific XML tags, prettify")
def add_josm_specific_xml_tags():
    """Post-process the restored OSM XML file (global ``args.restored_file``) for JOSM.

    - sets ``upload="false"`` on the root element;
    - marks every element that has a ``josmbug`` marker tag with ``action="modify"``
      and removes the marker tag;
    - rewrites the file pretty-printed, as UTF-8 with an XML declaration.
    """
    restored_path = args.restored_file
    document = etree.parse(restored_path, etree.XMLParser(remove_blank_text=True))
    document.getroot().set("upload", "false")

    # xpath() returns a list, so removing nodes while looping over it is safe
    for marker_tag in document.xpath("//tag[@k='josmbug']"):
        owner = marker_tag.getparent()
        owner.set("action", "modify")
        owner.remove(marker_tag)

    document.write(restored_path, encoding='utf-8', xml_declaration=True, pretty_print=True)


def is_tool(name) -> bool:
    """Return True if an executable called *name* can be found via ``shutil.which``."""
    found_at = which(name)
    return found_at is not None


def arg_parse():
    """Parse the command line into the module-level globals ``parser`` and ``args``.

    The file-name defaults contain a ``<countryname>`` placeholder that the caller
    replaces once the country name is known.
    """
    global parser, args
    parser = argparse.ArgumentParser(
        # description, not prog: prog is the program name shown in the usage line
        description="Python script to fix JOSM multiselect value override issue in "
                    "the past 12+ years. See bug #21375 in JOSM issue tracker."
    )

    parser.add_argument(
        "country",
        # codes are matched case-sensitively against the Geofabrik index, which
        # (like the examples) presumably uses upper case; accept "hu" too
        type=str.upper,
        help="Country ISO3166 code. E.g. \"HU\" for Hungary, \"IE\" for Ireland"
    )
    parser.add_argument(
        "-filtered",
        default="<countryname>-latest_filtered.osm.pbf",
        dest="filtered_file",
        help="Intermediate file used for osmium filtering. "
             "Use .osm extension if you want to examine the file manually."
    )
    parser.add_argument(
        "-collected",
        default="<countryname>-collected.osm.pbf",
        dest="collected_file",
        help="Intermediate file containing collected object history needed for restoration. "
             "Use .osm extension if you want to examine the file manually."
    )
    parser.add_argument(
        "-restored",
        default="<countryname>-restored.osm",
        dest="restored_file",
        help="Intermediate file containing restored values. Use .osm extension!"
    )
    parser.add_argument(
        "--update",
        action="store_true",
        help="Update the downloaded files, requires internet connection"
    )

    args = parser.parse_args()


if __name__ == '__main__':
    py_version_check()
    arg_parse()

    # 1. Download the latest data files
    downloader = Downloader(iso3166=args.country, update=args.update)
    # replace the "<countryname>" placeholder of the default file names
    country = downloader.country_name.lower()
    args.filtered_file = args.filtered_file.replace("<countryname>", country)
    args.collected_file = args.collected_file.replace("<countryname>", country)
    args.restored_file = args.restored_file.replace("<countryname>", country)

    downloader.print_header(downloader.history)
    downloader.print_header(downloader.latest)

    # 2. Search JOSM issues
    osmium_tool = OsmiumRunner()
    osmium_tool.filter_tags("tagfilter_prefix.txt", downloader.latest, args.filtered_file)

    # 3. Getting to know interesting objects
    i = IdCollector()
    i.apply_file(args.filtered_file, locations=True)

    if not interesting_fixable_objects:
        if interesting_unfixable_objects:
            # nothing to restore, but still report the objects that need manual fixing
            Print.status(
                f"Found {len(interesting_unfixable_objects)} objects with no history, "
                f"so they are unfixable by the script :(\n"
                f"they are {set(interesting_unfixable_objects)}", MsgLevel.MAJOR)
            Print.status("No fixable JOSM issues found, exiting...", MsgLevel.MAJOR)
        else:
            Print.status("No JOSM issues found, exiting...", MsgLevel.MAJOR)
        sys.exit(0)

    osmium_tool.object_by_id(downloader.history, args.collected_file, " ".join(interesting_fixable_objects))

    # 4. Look up and store previous values
    h = HistoryHandler()
    h.apply_file(args.collected_file, locations=True)

    # 5. Restore the values (remove the output of a previous run first)
    try:
        os.remove(args.restored_file)
    except FileNotFoundError:
        pass

    writer = osmium.SimpleWriter(args.restored_file)
    try:
        a = RestoreHandler(writer)
        a.apply_file(args.filtered_file, locations=True)
    finally:
        # close the output file even if restoring fails
        writer.close()

    # 6. Print some info about restoration
    if saved_states:
        Print.status(
            f"Found {len(saved_states)} objects with restorable values. "
            f"See the fixed objects in {args.restored_file}.", MsgLevel.MAJOR)
    # every detected object without a restored state still needs manual attention
    unfixable = (set(interesting_unfixable_objects) | set(interesting_fixable_objects)) - set(saved_states)
    if unfixable:
        Print.status(
            f"Found {len(unfixable)} objects with no history, so they are unfixable by the script :(\n"
            f"they are {unfixable}", MsgLevel.MAJOR)

    # 7. add action=modify attributes, cleanup XML
    add_josm_specific_xml_tags()
