"""Minimal pure-Python PostgreSQL client (frontend/backend protocol 3.0).

Exposes DB-API 2.0 style module globals and exceptions, a connect ()
factory, and a Connection class that runs simple (unparameterized) queries.
"""

import socket, sys, struct, types

try:
    import exceptions  # Python 2 only
    _StandardError = exceptions.StandardError
except ImportError:
    # Python 3 has neither the `exceptions` module nor StandardError.  This
    # fallback only keeps the module importable there; the protocol code
    # below still assumes Python 2 byte-string semantics.
    _StandardError = Exception


class Error (_StandardError):
    pass


class DatabaseError (Error):
    pass


class InterfaceError (Error):
    pass


apilevel = "2.0"
threadsafety = 1

# Type OIDs of built-in types that DataRow converts to Python values.
OID_BOOL = 16
OID_INT8 = 20
OID_INT4 = 23
OID_INT2 = 21
OID_OID = 26
OID_FLOAT4 = 700
OID_FLOAT8 = 701
OID_DATE = 1082
OID_TIMESTAMP = 1114
OID_TIME = 1083
OID_TIMESTAMPTZ = 1184
OID_TIMETZ = 1266

# Column decoders keyed by type OID:
#   (struct format for binary-format values, converter for text-format values)
# Binary values are in network byte order, hence the "!" prefix.
_TYPE_CODECS = {
    OID_INT2: ("!h", int),
    OID_INT4: ("!l", int),
    OID_OID: ("!L", int),  # OIDs are unsigned 32-bit
    OID_INT8: ("!q", int),
    OID_FLOAT4: ("!f", float),
    OID_FLOAT8: ("!d", float),
    OID_BOOL: ("!B", lambda raw: raw == 't'),
}


def _open_socket (family, socktype, proto, address):
    """Create a socket and connect it to address.

    Returns (sock, None) on success, or (None, error) with the socket.error
    that prevented it; a partially created socket is closed first.
    """
    try:
        sock = socket.socket (family, socktype, proto)
    except socket.error as e:
        return None, e
    try:
        sock.connect (address)
    except socket.error as e:
        sock.close ()
        return None, e
    return sock, None


def connect (database, user, password=None, host=None, port=5432):
    """Open a connection to a PostgreSQL server and log in.

    With host, every address returned by getaddrinfo (host, port) is tried
    over a stream socket in turn.  Without host, a Unix-domain socket named
    .s.PGSQL.<port> is tried in a few common directories.

    Returns a Connection for the first socket that connects.  Raises
    DatabaseError carrying the last socket error if none could be connected;
    errors raised while logging in propagate from the Connection constructor
    (the socket is closed in that case).
    """
    if host:
        targets = [(af, socktype, proto, sa)
                   for af, socktype, proto, _canonname, sa
                   in socket.getaddrinfo (host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)]
    else:
        # The socket file name carries the server's port number, so a
        # non-default port selects a different file in the same directories.
        # FIXME: list of socket directories used on various distros
        targets = [(socket.AF_UNIX, socket.SOCK_STREAM, 0,
                    '%s/.s.PGSQL.%s' % (directory, port))
                   for directory in ('/tmp', '/var/run/postgresql')]

    err_msg = ''
    for family, socktype, proto, address in targets:
        sock, error = _open_socket (family, socktype, proto, address)
        if sock is None:
            err_msg = error
            continue
        try:
            return Connection (sock, user, database, password)
        except BaseException:
            sock.close ()  # don't leak the socket when the login fails
            raise
    raise DatabaseError (err_msg)


class Connection:
    """One logged-in session on a PostgreSQL backend.

    Outgoing messages are assembled with add_struct ()/add_str () and
    flushed with send ().  Incoming messages are read one at a time by
    recv () and decoded with get_struct ()/get_str ()/get_raw ().
    getmessage () dispatches each incoming message to the handler registered
    for its type byte in self.handlers.
    """

    def __init__ (self, sock, user, db, password):
        """Take ownership of a connected socket and perform the startup handshake.

        Sends the StartupMessage for user/db, answers any authentication
        request, and returns once the server reports ReadyForQuery.  Raises
        InterfaceError for unsupported authentication methods or a closed
        connection, and DatabaseError if the server rejects the login.
        """
        self.sock = sock
        self.__password = password
        self.__user = user
        self.__input = ''
        self.__output = ''
        self.__pos = 0
        self.__loggedin = 0
        self.__status = None
        self.__complete = None
        self.__backend_pid = None
        self.__backend_key = None
        self.param = {}            # run-time parameters from ParameterStatus
        self.__buffer = []         # rows collected by DataRow for the current query
        self.__description = []    # column descriptions from the last RowDescription
        self.__notifications = []  # field dicts from NoticeResponse
        self.__listens = []        # entries from NotificationResponse
        self.handlers = {'R': self.AuthenticationRequest,
                         'K': self.BackendKeyData,
                         'C': self.CommandComplete,
                         'H': self.CopyOutResponse,
                         'G': self.CopyInResponse,
                         'd': self.CopyData,
                         'f': self.CopyFail,
                         'D': self.DataRow,
                         'E': self.ErrorResponse,
                         'N': self.NoticeResponse,
                         'A': self.NotificationResponse,
                         't': self.ParameterDescription,
                         'S': self.ParameterStatus,
                         'Z': self.ReadyForQuery,
                         'T': self.RowDescription}
        # StartupMessage: protocol version 3.0 (3 << 16 == 196608) followed
        # by NUL-terminated name/value pairs and a final empty string.
        self.add_struct ("!L", 196608)
        self.add_str ("user")
        self.add_str (user)
        self.add_str ("database")
        self.add_str (db)
        self.add_str ('')
        self.send ('')  # the startup message has no type byte
        self.getmessage ('Z')

    def send (self, message_type):
        """Flush the output buffer as one message, prefixed by its header.

        message_type is the one-character type byte, or '' for the untyped
        startup message.  The length field counts itself but not the type.
        """
        length = len (self.__output) + 4
        if message_type:
            header = struct.pack ("!cL", message_type, length)
        else:
            header = struct.pack ("!L", length)
        # sendall: a plain send () may transmit only part of the data.
        self.sock.sendall (header + self.__output)
        self.__output = ""

    def _recv_exact (self, count):
        """Read exactly count bytes from the socket.

        Raises InterfaceError if the server closes the connection first.
        """
        chunks = []
        while count > 0:
            chunk = self.sock.recv (count)
            if not chunk:
                raise InterfaceError ("connection closed by server")
            chunks.append (chunk)
            count -= len (chunk)
        return ''.join (chunks)

    def recv (self):
        """Read one complete backend message into the input buffer.

        Sets self.message_type and positions the read cursor just past the
        5-byte header (type byte + length).
        """
        self.__input = self._recv_exact (5)
        self.message_type, msg_len = struct.unpack ("!cL", self.__input)
        # msg_len includes the 4-byte length field itself.
        self.__input += self._recv_exact (msg_len - 4)
        self.__pos = 5

    def add_struct (self, fmt, *args):
        """Append struct.pack (fmt, *args) to the output buffer."""
        self.__output += struct.pack (fmt, *args)

    def get_struct (self, fmt):
        """Unpack fmt at the read position and advance past it.

        Returns the unpacked tuple, or None if the message is too short.
        """
        length = struct.calcsize (fmt)
        if length + self.__pos > len (self.__input):
            return None
        ret = struct.unpack (fmt, self.__input[self.__pos:self.__pos + length])
        self.__pos += length
        return ret

    def add_str (self, string):
        """Append string plus a NUL terminator; unicode is encoded as latin-1."""
        if isinstance (string, types.UnicodeType):
            string = string.encode ('latin-1', 'replace')
        self.__output += string
        self.__output += '\000'

    def get_str (self):
        """Return the NUL-terminated string at the read position.

        The cursor always moves past the terminating NUL, including for an
        empty string, so the following field is read from the right place.
        An unterminated string runs to the end of the message.
        """
        end = self.__input.find ('\000', self.__pos)
        if end == -1:
            end = len (self.__input)
        ret = self.__input[self.__pos:end]
        self.__pos = end + 1
        return ret

    def get_raw (self, length):
        """Return the next length bytes of the message and advance past them."""
        end = self.__pos + length
        ret = self.__input[self.__pos:end]
        self.__pos = end
        return ret

    def getmessage (self, endmessage):
        """Dispatch incoming messages until one whose type is in endmessage.

        This is the state machine that responds to incoming messages.  A
        DatabaseError raised by a handler (e.g. ErrorResponse) is held back
        until the terminating message has been read, so the connection stays
        in sync for the next command.  If the connection drops first, the
        held error is raised instead of the connection error.
        """
        pending = None
        while True:
            try:
                self.recv ()
            except (InterfaceError, socket.error):
                if pending is not None:
                    raise pending
                raise
            handler = self.handlers.get (self.message_type)
            if handler is not None:
                try:
                    handler ()
                except DatabaseError as e:
                    if pending is None:
                        pending = e
            if self.message_type in endmessage:
                break
        if pending is not None:
            raise pending

    def AuthenticationRequest (self):
        """Answer an authentication request ('R').

        Code 0 marks the login as successful.  Codes 3/4/5 are answered
        with a cleartext, crypt () or md5 password.  Raises InterfaceError
        for Kerberos, for unknown codes, or when a password is required but
        none was given.
        """
        mode, = self.get_struct ("!L")
        if mode == 0:
            self.__loggedin = 1
        elif mode == 1:
            raise InterfaceError ("Kerberos V4 not supported")
        elif mode == 2:
            raise InterfaceError ("Kerberos V5 not supported")
        elif mode in (3, 4, 5):
            if self.__password is None:
                raise InterfaceError ("server requested a password but none was given")
            if mode == 3:
                # clear text password
                response = self.__password
            elif mode == 4:
                # crypt () password with a 2-byte salt
                salt, = self.get_struct ("2s")
                import crypt
                response = crypt.crypt (self.__password, salt)
            else:
                # md5 password: "md5" + md5 (md5 (password + user) + salt)
                salt, = self.get_struct ("4s")
                import hashlib
                inner = hashlib.md5 (self.__password + self.__user).hexdigest ()
                response = "md5" + hashlib.md5 (inner + salt).hexdigest ()
            self.add_str (response)
            self.send ('p')
        else:
            raise InterfaceError ("unknown authentication mode %d" % mode)

    def BackendKeyData (self):
        """Record the backend process id and secret key ('K')."""
        self.__backend_pid, self.__backend_key = self.get_struct ("!LL")

    def CommandComplete (self):
        """Record the command tag of the finished command ('C')."""
        self.__complete = self.get_str ()

    def CopyData (self):
        # COPY TO STDOUT payload; this client discards it.
        pass

    def CopyDone (self):
        pass

    def CopyOutResponse (self):
        # The data that follows arrives as CopyData messages, which are ignored.
        pass

    def CopyInResponse (self):
        # This client cannot supply COPY FROM STDIN data.  Abort the copy
        # with a CopyFail message instead of leaving both sides waiting; the
        # server reports the failed COPY as an error before ReadyForQuery.
        self.add_str ("COPY FROM STDIN is not supported by this client")
        self.send ('f')

    def CopyFail (self):
        raise DatabaseError (self.get_str ())

    def DataRow (self):
        """Decode one result row ('D') and append it to the row buffer.

        SQL NULL becomes None.  int2/int4/int8/oid, float4/float8 and bool
        columns are converted to Python values in both text and binary
        format; any other type is returned as the raw string.
        """
        cols, = self.get_struct ("!h")
        row = []
        for i in range (cols):
            itemlen, = self.get_struct ("!l")
            if itemlen == -1:
                row.append (None)  # map SQL NULL to Python None
                continue
            raw = self.get_raw (itemlen)
            column = self.__description[i]
            codec = _TYPE_CODECS.get (column['typeOID'])
            if codec is None:
                row.append (raw)
            elif column['format'] == 1:
                value, = struct.unpack (codec[0], raw)
                if column['typeOID'] == OID_BOOL:
                    value = bool (value)
                row.append (value)
            else:
                row.append (codec[1] (raw))
        self.__buffer.append (row)

    def _read_fields (self):
        """Parse the (code byte, string) field list of an error or notice message.

        Returns a dict mapping each one-character field code to its value;
        stops at the NUL terminator or at the end of the message.
        """
        fields = {}
        while True:
            code = self.get_struct ("c")
            if code is None or code[0] == '\000':
                break
            fields[code[0]] = self.get_str ()
        return fields

    def ErrorResponse (self):
        """Raise DatabaseError from the severity (S) and message (M) fields ('E')."""
        error = self._read_fields ()
        raise DatabaseError ("%s: %s" % (error.get ('S', ''), error.get ('M', '')))

    def NoticeResponse (self):
        """Store the fields of a notice ('N') in the notification list."""
        self.__notifications.append (self._read_fields ())

    def NotificationResponse (self):
        """Store an asynchronous notification ('A'): sender PID, event name, payload."""
        listen = {}
        listen['PID'], = self.get_struct ("!L")
        listen['event'] = self.get_str ()
        listen['param'] = self.get_str ()
        self.__listens.append (listen)

    def ParameterDescription (self):
        pass

    def ParameterStatus (self):
        """Record a run-time parameter name/value pair ('S') in self.param."""
        p = self.get_str ()
        self.param[p] = self.get_str ()

    def ReadyForQuery (self):
        """Record the transaction-status byte ('Z')."""
        self.__status, = self.get_struct ("c")

    def RowDescription (self):
        """Replace the column descriptions used by DataRow ('T')."""
        rows, = self.get_struct ("!h")
        self.__description = []
        for i in range (rows):
            row = {}
            row['name'] = self.get_str ()
            # table OID, column number, type OID, type length, type modifier,
            # format code (0 = text, 1 = binary); OIDs are unsigned.
            (row['tableOID'], row['col_no'], row['typeOID'], row['typlen'],
             row['modifier'], row['format']) = self.get_struct ("!LhLhlh")
            self.__description.append (row)

    def close (self):
        """Send Terminate and close the socket (closed even if sending fails)."""
        try:
            self.send ('X')
        finally:
            self.sock.close ()

    def query (self, query):
        """Run query through the simple-query protocol and return its rows.

        Returns a list of rows (lists of column values, see DataRow).
        Raises DatabaseError if the server reports an error; all messages
        up to ReadyForQuery are still consumed in that case.
        """
        self.__buffer = []
        self.add_str (query)
        self.send ('Q')
        self.getmessage ('Z')
        return self.__buffer