"""
 Module  : dburi.py
 License : BSD License (see LICENSE.txt)

This module provides a series of utility classes and functions to return a 
database connection from a URI.

These URIs take one of the following forms:
      'mysql://username[:password]@host[:port]/database name'
      'sqlite://path/to/db/file'
      'sqlite:/C|/path/to/db/file' - On MS Windows
      'sqlite:/:memory:' - For an in memory database
      'oracle://username:password@tns entry'

This module is inspired by (and borrows somewhat from) SQLObject's dbconnection.py; I have purposely left out much of that module's baggage.

To do:
 - Add support for PostgreSQL
 - Add ODBC support via pyodbc - http://pyodbc.sourceforge.net/
"""
# Module metadata: the version tuple (presumably (major, minor, patch)),
# the release date as a (year, month, day) tuple, and the author contact.
__version__ = (0, 1, 2)
__date__ = (2006, 9, 20)
__author__ = "Andy Todd <andy47@halfcooked.com>"

# Module-level debug flag.
# NOTE(review): not referenced by any code in this file; confirm whether
# callers elsewhere read it before removing.
debug = False

# Module-wide logger shared by the connection classes below for debug output.
# I may replace this with more generic logging
from utilities.Log import get_log
log = get_log()

class Connection(object):
    """Common base class for the per-database connection helpers.

    Each subclass opens a DB-API connection in its constructor and stores
    it on the instance as the ``connection`` attribute.
    """

class MySqlConnection(Connection):
    """Open a MySQL connection from the remainder of a ``mysql:`` URI.

    ``connection_string`` is the part of the URI after ``mysql:/``, i.e.
    ``/username[:password]@host[:port][/database name]``. The open DB-API
    connection is stored on ``self.connection``.
    """

    def __init__(self, connection_string):
        """Parse ``connection_string`` and connect with MySQLdb.

        Raises:
            ImportError: if the MySQLdb module is not installed.
            ValueError: if ``connection_string`` cannot be parsed.
        """
        try:
            import MySQLdb as db
        except ImportError:
            raise ImportError("Can't connect to MySQL as db-api module not present")

        username, password, host, port, dbName = self.parse_uri(connection_string)
        # Substitute fixed fallbacks for any part the URI leaves out.
        self.connection = db.connect(user=username or '', passwd=password or '',
                                     host=host or 'localhost', port=port or 0,
                                     db=dbName or '')

    def parse_uri(self, connection_string):
        """Turn the connection_string into a series of parameters to the connect method.

        Returns a ``(username, password, host, port, dbName)`` tuple:
        ``password`` is None when no ``:password`` is given, ``port`` is None
        when no ``:port`` is given (otherwise an int in 1-65535), and
        ``dbName`` is '' when no ``/database name`` is given.

        Raises:
            ValueError: if there is no '@' separating the credentials from
                the host, or the port is not an integer in 1-65535.
        """
        # Strip the leading '/'
        if connection_string.startswith('/'):
            connection_string = connection_string[1:]
        if '@' not in connection_string:
            raise ValueError("MySqlConnection passed invalid connection_string")

        # Split on the LAST '@' so a password containing '@' survives.
        credentials, rest = connection_string.rsplit('@', 1)
        # Split on the FIRST ':' so a password containing ':' survives;
        # a bare username must still yield a (None) password.
        if ':' in credentials:
            username, password = credentials.split(':', 1)
        else:
            username, password = credentials, None

        # Take the rest and split into its host, port and db name parts.
        if '/' in rest:
            host, dbName = rest.split('/', 1)
        else:
            host, dbName = rest, ''
        if ':' in host:
            host, port = host.split(':', 1)
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
        else:
            port = None
        return username, password, host, port, dbName

class SqliteConnection(Connection):
    """Open a SQLite connection from the remainder of a ``sqlite:`` URI.

    ``connection_string`` is the part of the URI after ``sqlite:/``: a file
    path, ``:memory:`` for an in-memory database, or a Windows path with
    ``|`` written in place of the drive colon (e.g. ``C|/path/to/db``).
    The open DB-API connection is stored on ``self.connection``.
    """

    def __init__(self, connection_string):
        """Connect to the SQLite database named by ``connection_string``.

        Raises:
            ImportError: if neither pysqlite2 nor sqlite3 is importable.
        """
        try:
            from pysqlite2 import dbapi2 as db
        except ImportError:
            try:
                from sqlite3 import dbapi2 as db # For Python 2.5 and above
            except ImportError:
                raise ImportError("Can't connect to sqlite as db-api module not present")
        # The URI form writes a Windows drive colon as '|' (e.g. 'C|/path');
        # map it back to ':'. str.replace returns a new string, so the result
        # must be assigned back or the conversion is lost.
        connection_string = connection_string.replace('|', ':')
        log.debug(connection_string)
        self.connection = db.connect(connection_string)

class OracleConnection(Connection):
    """Open an Oracle connection from the remainder of an ``oracle:`` URI.

    ``connection_string`` is the part of the URI after ``oracle:/``, i.e.
    ``/username:password@tns entry``. It is rewritten into the
    ``username/password@tns entry`` form before being passed to the
    driver's ``connect``. The open DB-API connection is stored on
    ``self.connection``.
    """

    def __init__(self, connection_string):
        """Connect to Oracle using cx_Oracle, falling back to dcoracle2.

        Raises:
            ImportError: if neither driver module can be imported.
        """
        try:
            import cx_Oracle as db
        except ImportError:
            try:
                # NOTE(review): the DCOracle2 package may install its module
                # as 'DCOracle2' (mixed case); confirm this import name.
                import dcoracle2 as db
            except ImportError:
                raise ImportError("Can't connect to Oracle as db-api module not present")
        # Remove the leading / from the connection string
        if connection_string.startswith('/'):
            connection_string = connection_string[1:]
        # Only the credentials (before the last '@') use ':' as the
        # user/password separator. Leave any ':' in the TNS entry (e.g.
        # 'host:1521/service') untouched.
        if '@' in connection_string:
            credentials, tns = connection_string.rsplit('@', 1)
            tns = '@' + tns
        else:
            credentials, tns = connection_string, ''
        if ':' in credentials:
            # Split on the first ':' only so a password may contain ':'.
            username, password = credentials.split(':', 1)
            connection_string = '%s/%s%s' % (username, password, tns)
            # Keep the password out of the log.
            loggable = '%s/***%s' % (username, tns)
        else:
            connection_string = credentials + tns
            loggable = connection_string
        # Connect to the database
        log.debug(loggable)
        self.connection = db.connect(connection_string)

def get_connection(uri):
    """Get and return a database connection based on the uri

    The uri scheme is blatantly ripped off from SQLObject. The general form of
    these uris is:
      'mysql://username[:password]@host[:port]/database name'
      'sqlite://path/to/db/file'    - opens '/path/to/db/file'
      'sqlite:/path/to/db/file'     - opens 'path/to/db/file'
      'sqlite:/C|/path/to/db/file'  - On MS Windows
      'sqlite:/:memory:'            - For an in memory database
      'oracle://username:password@tns entry'

    Only the first ':/' separates the scheme from the rest, so the remainder
    may itself contain ':/' (for example inside a password).

    Returns the DB-API connection opened by the helper class for the scheme.

    Raises:
        ValueError: if ``uri`` contains no ':/' separator.
        KeyError: if the scheme is not 'mysql', 'sqlite' or 'oracle'.
        ImportError, ValueError: propagated from the helper classes.
    """
    # Validate before anything else. The uri is deliberately not echoed in
    # the message because it may contain a password.
    if ':/' not in uri:
        raise ValueError("invalid database uri: expected the form 'scheme:/...'")
    scheme, connection_string = uri.split(':/', 1)
    helpers = {'mysql': MySqlConnection,
               'sqlite': SqliteConnection,
               'oracle': OracleConnection}
    try:
        helper = helpers[scheme]
    except KeyError:
        raise KeyError("unsupported database uri scheme %r" % scheme)
    return helper(connection_string).connection
