# -*- coding: utf-8 -*-
"""This module defines functions for handling PDB sequence clusters."""

import os.path

from numpy import array, abs

from prody import LOGGER, SETTINGS, getPackagePath
from prody.utilities import openFile, openURL, pystr, isListLike

from numbers import Integral

__all__ = ['fetchPDBClusters', 'loadPDBClusters', 'listPDBCluster']

# Cache of loaded clusters keyed by sequence identity level (%); every value
# stays None until loadPDBClusters stores the clusters for that level.
PDB_CLUSTERS = dict.fromkeys((30, 40, 50, 70, 90, 95, 100))
# loadPDBClusters warns about stale cluster files while this is True and
# switches it off after the first warning.
PDB_CLUSTERS_UPDATE_WARNING = True
# Available identity levels in ascending order, as a numpy array so that a
# requested level can be snapped to the nearest available one.
PDB_CLUSTERS_SQIDS = array(sorted(PDB_CLUSTERS))
# Comma-separated list of the available levels, used in error messages.
PDB_CLUSTERS_SQID_STR = ', '.join(map(str, PDB_CLUSTERS_SQIDS))

def loadPDBClusters(sqid=None):
    """Load previously fetched PDB sequence clusters from disk to memory.

    :arg sqid: sequence identity level to load, one of the keys of
        ``PDB_CLUSTERS`` (30, 40, 50, 70, 90, 95, 100).  When **None**
        (default), clusters for all levels are loaded.
    :type sqid: int

    Cluster files missing from disk are first downloaded with
    :func:`fetchPDBClusters`.  If a cluster file is more than one week old,
    a warning is logged; this happens at most once because the warning flag
    is module-global.

    Returns the ``PDB_CLUSTERS`` dictionary (identity level -> clusters)
    when *sqid* is **None**, otherwise the list of clusters for *sqid*.
    Each cluster is a list of tuples obtained by splitting entries on
    ``'_'``, e.g. ``[('1XXX', 'A'), ...]``.

    :raises ValueError: if *sqid* is not an available identity level
    :raises IOError: if a missing cluster file could not be downloaded"""

    import time

    PDB_CLUSTERS_PATH = os.path.join(getPackagePath(), 'pdbclusters')
    if sqid is None:
        sqid_list = list(PDB_CLUSTERS)
        LOGGER.info('Loading all PDB sequence clusters.')
    else:
        assert isinstance(sqid, Integral), 'sqid must be an integer'
        if sqid not in PDB_CLUSTERS:
            raise ValueError('PDB cluster data is not available for sequence '
                             'identity {0}%, try one of {1}'
                             .format(sqid, PDB_CLUSTERS_SQID_STR))
        LOGGER.info('Loading PDB sequence clusters for sequence identity '
                    '{0}.'.format(sqid))
        sqid_list = [sqid]

    global PDB_CLUSTERS_UPDATE_WARNING
    clusters = None
    # Use a distinct loop variable: rebinding *sqid* here would make the
    # "sqid is None" check below always false.
    for level in sqid_list:
        filename = os.path.join(PDB_CLUSTERS_PATH,
                                'bc-{0}.out.gz'.format(level))
        if not os.path.isfile(filename):
            fetchPDBClusters(level)
            # fetchPDBClusters only logs a warning on failure, so check
            # again instead of failing later inside os.path.getmtime.
            if not os.path.isfile(filename):
                raise IOError('PDB sequence clusters for sequence identity '
                              '{0}% could not be fetched.'.format(level))

        if PDB_CLUSTERS_UPDATE_WARNING:
            # 604800 seconds == one week
            diff = (time.time() - os.path.getmtime(filename)) / 604800.
            if diff > 1.:
                LOGGER.warning('PDB sequence clusters are {0:.1f} week(s) old,'
                               ' call `fetchPDBClusters` to receive updates.'
                               .format(diff))
                PDB_CLUSTERS_UPDATE_WARNING = False

        inp = openFile(filename)
        try:
            clusters_str = pystr(inp.read())
        finally:
            inp.close()

        # One cluster per non-blank line; members are whitespace separated.
        clusters = []
        for cluster_str in clusters_str.split('\n'):
            cluster_str = cluster_str.strip()
            if cluster_str:
                clusters.append([tuple(item.split('_'))
                                 for item in cluster_str.split()])

        PDB_CLUSTERS[level] = clusters

    if sqid is None:
        return PDB_CLUSTERS
    return clusters


def listPDBCluster(pdb, ch, sqid=95):
    """Returns the PDB sequence cluster that contains chain *ch* in structure
    *pdb* for sequence identity level *sqid*.  The cluster is returned as a
    list of tuples, e.g. ``[('1XXX', 'A'), ]``, or **None** when no cluster
    contains the chain.  Note that PDB clusters individual chains, so the
    same PDB identifier may appear more than once in the same cluster if the
    corresponding chain is present in the structure more than once.

    :arg pdb: four-character PDB identifier, matched case-insensitively
    :arg ch: one-character chain identifier, matched exactly
    :arg sqid: sequence identity level between 30 and 100.  It is converted
        with :func:`int` and snapped to the nearest available level (30, 40,
        50, 70, 90, 95, 100); ties go to the lower level.

    Cluster data for the selected level is loaded on first use with
    :func:`loadPDBClusters`, which downloads it with
    :func:`fetchPDBClusters` if it is not already on disk.

    :raises TypeError: if *sqid* cannot be converted to an integer
    :raises ValueError: if *sqid* is outside the range 30-100"""

    assert isinstance(pdb, str) and len(pdb) == 4, \
        'pdb must be 4 char long string'
    assert isinstance(ch, str) and len(ch) == 1, \
        'ch must be a one char long string'
    try:
        sqid = int(sqid)
    except TypeError:
        raise TypeError('sqid must be an integer')
    if not (30 <= sqid <= 100):
        raise ValueError('sqid must be between 30 and 100')

    # Snap to the closest available level; argmin returns the first minimum,
    # so a tie (e.g. 35) resolves to the lower level.
    sqid = PDB_CLUSTERS_SQIDS[abs(PDB_CLUSTERS_SQIDS - sqid).argmin()]
    clusters = PDB_CLUSTERS[sqid]
    if clusters is None:
        loadPDBClusters(sqid)
        clusters = PDB_CLUSTERS[sqid]

    pdb_ch = (pdb.upper(), ch)
    for cluster in clusters:
        if pdb_ch in cluster:
            return cluster
    return None

def fetchPDBClusters(sqid=None):
    """Retrieve PDB sequence clusters.  PDB sequence clusters are results of
    the weekly clustering of protein chains in the PDB generated by blastclust.
    They are available at https://cdn.rcsb.org/resources/sequence/clusters/

    :arg sqid: sequence identity level(s) to download, an integer or a
        list-like of integers, each one of 30, 40, 50, 70, 90, 95, or 100.
        When **None** (default), clusters for all levels are downloaded.

    Each file is saved compressed as ``bc-<sqid>.out.gz`` in the
    :file:`pdbclusters` folder under the ProDy package path
    (:file:`.prody/pdbclusters` in your home directory by default).
    A level whose download fails with :exc:`IOError` is skipped with a
    warning, and any previously saved file for that level is left untouched.
    Cluster data can be loaded using :func:`loadPDBClusters` function and be
    accessed using :func:`listPDBCluster`.

    :raises ValueError: if any requested *sqid* is not an available level"""

    if sqid is None:
        keys = list(PDB_CLUSTERS)
    else:
        # Materialize once so validation and downloading see the same items.
        keys = list(sqid) if isListLike(sqid) else [sqid]
        for key in keys:
            if key not in PDB_CLUSTERS:
                raise ValueError('sqid must be one or more of ' +
                                 PDB_CLUSTERS_SQID_STR)

    PDB_CLUSTERS_PATH = os.path.join(getPackagePath(), 'pdbclusters')
    if not os.path.isdir(PDB_CLUSTERS_PATH):
        # makedirs also creates missing parent directories
        os.makedirs(PDB_CLUSTERS_PATH)

    LOGGER.progress('Downloading sequence clusters', len(keys),
                    '_prody_fetchPDBClusters')
    count = 0
    for i, x in enumerate(keys):
        filename = 'bc-{0}.out'.format(x)
        url = 'https://cdn.rcsb.org/resources/sequence/clusters/' + filename
        try:
            inp = openURL(url)
            try:
                # Read everything before opening the output file, so that a
                # failed transfer never truncates an existing good file.
                data = inp.read()
            finally:
                inp.close()
        except IOError:
            LOGGER.warning('Clusters at {0}% sequence identity level could '
                           'not be downloaded.'.format(x))
        else:
            out = openFile(filename + '.gz', 'w', folder=PDB_CLUSTERS_PATH)
            try:
                out.write(data)
            finally:
                out.close()
            count += 1
        # Advance the progress bar for failed downloads too.
        LOGGER.update(i, label='_prody_fetchPDBClusters')
    LOGGER.finish()

    if len(keys) == count:
        LOGGER.info('All selected PDB clusters were downloaded successfully.')
    elif count == 0:
        LOGGER.warning('PDB clusters could not be downloaded.')
