Added initial unfs3 sources for version 0.9.22+dfsg-1maemo2
[unfs3] / unfs3 / contrib / nfsotpclient / rpc.py
1
2 # rpc.py - RFC1057/RFC1831
3 #
4 # Copyright (C) 2001 Cendio Systems AB (http://www.cendio.se)
5 # All rights reserved.
6
7 # Copyright (c) 2001 Python Software Foundation.
8 # All rights reserved.
9
10 # Copyright (c) 2000 BeOpen.com.
11 # All rights reserved.
12
13 # Copyright (c) 1995-2001 Corporation for National Research Initiatives.
14 # All rights reserved.
15
16 # Copyright (c) 1991-1995 Stichting Mathematisch Centrum.
17 # All rights reserved.
18
19 #
20 # This program is free software; you can redistribute it and/or modify
21 # it under the terms of the GNU General Public License as published by
22 # the Free Software Foundation; version 2 of the License. 
23
24 # This program is distributed in the hope that it will be useful,
25 # but WITHOUT ANY WARRANTY; without even the implied warranty of
26 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
27 # GNU General Public License for more details.
28
29 # You should have received a copy of the GNU General Public License
30 # along with this program; if not, write to the Free Software
31 # Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
32
33 # XXX The UDP version of the protocol resends requests when it does
34 # XXX not receive a timely reply -- use only for idempotent calls!
35
36 # XXX There is no provision for call timeout on TCP connections
37
38 __pychecker__ = 'no-callinit'
39
40 import xdrlib
41 import socket
42 import os
43 import time
44
45 RPCVERSION = 2
46
47 CALL = 0
48 REPLY = 1
49
50 AUTH_NULL = 0
51 AUTH_UNIX = 1
52 AUTH_SHORT = 2
53 AUTH_DES = 3
54
55 MSG_ACCEPTED = 0
56 MSG_DENIED = 1
57
58 SUCCESS = 0                             # RPC executed successfully
59 PROG_UNAVAIL  = 1                       # remote hasn't exported program
60 PROG_MISMATCH = 2                       # remote can't support version #
61 PROC_UNAVAIL  = 3                       # program can't support procedure
62 GARBAGE_ARGS  = 4                       # procedure can't decode params
63
64 RPC_MISMATCH = 0                        # RPC version number != 2
65 AUTH_ERROR = 1                          # remote can't authenticate caller
66
67 AUTH_BADCRED      = 1                   # bad credentials (seal broken)
68 AUTH_REJECTEDCRED = 2                   # client must begin new session
69 AUTH_BADVERF      = 3                   # bad verifier (seal broken)
70 AUTH_REJECTEDVERF = 4                   # verifier expired or replayed
71 AUTH_TOOWEAK      = 5                   # rejected for security reasons
72
73 # All RPC errors are subclasses of RPCException
74 class RPCException(Exception):
75     def __str__(self):
76         return "RPCException"
77
78 class BadRPCMsgType(RPCException):
79     def __init__(self, msg_type):
80         self.msg_type = msg_type
81
82     def __str__(self):
83         return "BadRPCMsgType %d" % self.msg_type
84
85 class BadRPCVersion(RPCException):
86     def __init__(self, version):
87         self.version = version
88
89     def __str__(self):
90         return "BadRPCVersion %d" % self.version
91
92 class RPCMsgDenied(RPCException):
93     # MSG_DENIED
94     def __init__(self, stat):
95         self.stat = stat
96
97     def __str__(self):
98         return "RPCMsgDenied %d" % self.stat
99
100 class RPCMisMatch(RPCException):
101     # MSG_DENIED: RPC_MISMATCH
102     def __init__(self, low, high):
103         self.low = low
104         self.high = high
105
106     def __str__(self):
107         return "RPCMisMatch %d,%d" % (self.low, self.high)
108
109 class RPCAuthError(RPCException):
110     # MSG_DENIED: AUTH_ERROR
111     def __init__(self, stat):
112         self.stat = stat
113
114     def __str__(self):
115         return "RPCAuthError %d" % self.stat
116
117 class BadRPCReplyType(RPCException):
118     # Neither MSG_DENIED nor MSG_ACCEPTED
119     def __init__(self, msg_type):
120         self.msg_type = msg_type
121
122     def __str__(self):
123         return "BadRPCReplyType %d" % self.msg_type
124
125 class RPCProgUnavail(RPCException):
126     # PROG_UNAVAIL
127     def __str__(self):
128         return "RPCProgUnavail"
129
130 class RPCProgMismatch(RPCException):
131     # PROG_MISMATCH
132     def __init__(self, low, high):
133         self.low = low
134         self.high = high
135
136     def __str__(self):
137         return "RPCProgMismatch %d,%d" % (self.low, self.high)
138
139 class RPCProcUnavail(RPCException):
140     # PROC_UNAVAIL
141     def __str__(self):
142         return "RPCProcUnavail"
143
144 class RPCGarbageArgs(RPCException):
145     # GARBAGE_ARGS
146     def __str__(self):
147         return "RPCGarbageArgs"
148
149 class RPCUnextractedData(RPCException):
150     # xdrlib raised exception because unextracted data remained
151     def __str__(self):
152         return "RPCUnextractedData"
153
154 class RPCBadAcceptStats(RPCException):
155     # Unknown accept_stat
156     def __init__(self, stat):
157         self.stat = stat
158
159     def __str__(self):
160         return "RPCBadAcceptStats %d" % self.stat
161
162 class XidMismatch(RPCException):
163     # Got wrong XID in reply, got "xid" instead of "expected"
164     def __init__(self, xid, expected):
165         self.xid = xid
166         self.expected = expected
167
168     def __str__(self):
169         return "XidMismatch %d,%d" % (self.xid, self.expected)
170
171 class TimeoutError(RPCException):
172     pass
173
174 class PortMapError(RPCException):
175     pass
176
177
178 class Packer(xdrlib.Packer):
179
180     def pack_auth(self, auth):
181         flavor, stuff = auth
182         self.pack_enum(flavor)
183         self.pack_opaque(stuff)
184
185     def pack_auth_unix(self, stamp, machinename, uid, gid, gids):
186         self.pack_uint(stamp)
187         self.pack_string(machinename)
188         self.pack_uint(uid)
189         self.pack_uint(gid)
190         self.pack_uint(len(gids))
191         for i in gids:
192             self.pack_uint(i)
193
194     def pack_callheader(self, xid, prog, vers, proc, cred, verf):
195         self.pack_uint(xid)
196         self.pack_enum(CALL)
197         self.pack_uint(RPCVERSION)
198         self.pack_uint(prog)
199         self.pack_uint(vers)
200         self.pack_uint(proc)
201         self.pack_auth(cred)
202         self.pack_auth(verf)
203         # Caller must add procedure-specific part of call
204
205     def pack_replyheader(self, xid, verf):
206         self.pack_uint(xid)
207         self.pack_enum(REPLY)
208         self.pack_uint(MSG_ACCEPTED)
209         self.pack_auth(verf)
210         self.pack_enum(SUCCESS)
211         # Caller must add procedure-specific part of reply
212
213
214 class Unpacker(xdrlib.Unpacker):
215
216     def unpack_auth(self):
217         flavor = self.unpack_enum()
218         stuff = self.unpack_opaque()
219         if flavor == AUTH_UNIX:
220                 p = Unpacker(stuff)
221                 stuff = p.unpack_auth_unix()
222         return (flavor, stuff)
223
224     def unpack_auth_unix(self):
225         stamp=self.unpack_uint()
226         machinename=self.unpack_string()
227         print "machinename: %s" % machinename
228         uid=self.unpack_uint()
229         gid=self.unpack_uint()
230         n_gids=self.unpack_uint()
231         gids = []
232         print "n_gids: %d" % n_gids
233         for i in range(n_gids):
234             gids.append(self.unpack_uint())
235         return stamp, machinename, uid, gid, gids
236
237
238     def unpack_callheader(self):
239         xid = self.unpack_uint()
240         msg_type = self.unpack_enum()
241         if msg_type <> CALL:
242             raise BadRPCMsgType(msg_type)
243         rpcvers = self.unpack_uint()
244         if rpcvers <> RPCVERSION:
245             raise BadRPCVersion(rpcvers)
246         prog = self.unpack_uint()
247         vers = self.unpack_uint()
248         proc = self.unpack_uint()
249         cred = self.unpack_auth()
250         verf = self.unpack_auth()
251         return xid, prog, vers, proc, cred, verf
252         # Caller must add procedure-specific part of call
253
254     def unpack_replyheader(self):
255         xid = self.unpack_uint()
256         msg_type = self.unpack_enum()
257         if msg_type <> REPLY:
258             raise BadRPCMsgType(msg_type)
259         stat = self.unpack_enum()
260         if stat == MSG_DENIED:
261             stat = self.unpack_enum()
262             if stat == RPC_MISMATCH:
263                 low = self.unpack_uint()
264                 high = self.unpack_uint()
265                 raise RPCMisMatch(low, high)
266             if stat == AUTH_ERROR:
267                 stat = self.unpack_uint()
268                 raise RPCAuthError(stat)
269             raise RPCMsgDenied(stat)
270         if stat <> MSG_ACCEPTED:
271             raise BadRPCReplyType(stat)
272         verf = self.unpack_auth()
273         stat = self.unpack_enum()
274         if stat == PROG_UNAVAIL:
275             raise RPCProgUnavail()
276         if stat == PROG_MISMATCH:
277             low = self.unpack_uint()
278             high = self.unpack_uint()
279             raise RPCProgMismatch(low, high)
280         if stat == PROC_UNAVAIL:
281             raise RPCProcUnavail()
282         if stat == GARBAGE_ARGS:
283             raise RPCGarbageArgs()
284         if stat <> SUCCESS:
285             raise RPCBadAcceptStats(stat)
286         return xid, verf
287         # Caller must get procedure-specific part of reply
288
289
290 # Subroutines to create opaque authentication objects
291
292 def make_auth_null():
293     return ''
294
295 def make_auth_unix(seed, host, uid, gid, groups):
296     p = Packer()
297     p.pack_auth_unix(seed, host, uid, gid, groups)
298     return p.get_buffer()
299
300 def make_auth_unix_default():
301     try:
302         uid = os.getuid()
303         gid = os.getgid()
304     except AttributeError:
305         uid = gid = 0
306     return make_auth_unix(int(time.time()-unix_epoch()), \
307               socket.gethostname(), uid, gid, [])
308
309 _unix_epoch = -1
310 def unix_epoch():
311     """Very painful calculation of when the Unix Epoch is.
312
313     This is defined as the return value of time.time() on Jan 1st,
314     1970, 00:00:00 GMT.
315
316     On a Unix system, this should always return 0.0.  On a Mac, the
317     calculations are needed -- and hard because of integer overflow
318     and other limitations.
319
320     """
321     global _unix_epoch
322     if _unix_epoch >= 0: return _unix_epoch
323     now = time.time()
324     localt = time.localtime(now)        # (y, m, d, hh, mm, ss, ..., ..., ...)
325     gmt = time.gmtime(now)
326     offset = time.mktime(localt) - time.mktime(gmt)
327     y, m, d, hh, mm, ss = 1970, 1, 1, 0, 0, 0
328     offset, ss = divmod(ss + offset, 60)
329     offset, mm = divmod(mm + offset, 60)
330     offset, hh = divmod(hh + offset, 24)
331     d = d + offset
332     _unix_epoch = time.mktime((y, m, d, hh, mm, ss, 0, 0, 0))
333     print "Unix epoch:", time.ctime(_unix_epoch)
334     return _unix_epoch
335
336
337 # Common base class for clients
338
339 class Client:
340
341     def __init__(self, host, prog, vers, port):
342         self.host = host
343         self.prog = prog
344         self.vers = vers
345         self.port = port
346         self.sock = None
347         self.makesocket() # Assigns to self.sock
348         self.bindsocket()
349         self.connsocket()
350         # Servers may do XID caching, so try to come up with something
351         # unique to start with. XIDs are 32 bits. Python integers are always
352         # at least 32 bits. 
353         self.lastxid = int(long(time.time() * 1E6) & 0xfffffff)
354         self.addpackers()
355         self.cred = None
356         self.verf = None
357
358     def close(self):
359         self.sock.close()
360
361     def makesocket(self):
362         # This MUST be overridden
363         raise RuntimeError("makesocket not defined")
364
365     def connsocket(self):
366         # Override this if you don't want/need a connection
367         self.sock.connect((self.host, self.port))
368
369     def bindsocket(self):
370         # Override this to bind to a different port (e.g. reserved)
371         self.sock.bind(('', 0))
372
373     def addpackers(self):
374         # Override this to use derived classes from Packer/Unpacker
375         self.packer = Packer()
376         self.unpacker = Unpacker('')
377
378     def make_call(self, proc, args, pack_func, unpack_func):
379         # Don't normally override this (but see Broadcast)
380         if pack_func is None and args is not None:
381             raise TypeError("non-null args with null pack_func")
382         self.start_call(proc)
383         if pack_func:
384             pack_func(args)
385         self.do_call()
386         if unpack_func:
387             result = unpack_func()
388         else:
389             result = None
390         try:
391             self.unpacker.done()
392         except xdrlib.Error:
393             raise RPCUnextractedData()
394             
395         return result
396
397     def start_call(self, proc):
398         # Don't override this
399         self.lastxid = xid = self.lastxid + 1
400         cred = self.mkcred()
401         verf = self.mkverf()
402         p = self.packer
403         p.reset()
404         p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
405
406     def do_call(self):
407         # This MUST be overridden
408         raise RuntimeError("do_call not defined")
409
410     def mkcred(self):
411         # Override this to use more powerful credentials
412         if self.cred == None:
413             self.cred = (AUTH_NULL, make_auth_null())
414         return self.cred
415
416     def mkverf(self):
417         # Override this to use a more powerful verifier
418         if self.verf == None:
419             self.verf = (AUTH_NULL, make_auth_null())
420         return self.verf
421
422     def call_0(self):           # Procedure 0 is always like this
423         return self.make_call(0, None, None, None)
424
425
426 # Record-Marking standard support
427
428 def sendfrag(sock, last, frag):
429     x = len(frag)
430     if last: x = x | 0x80000000L
431     header = (chr(int(x>>24 & 0xff)) + chr(int(x>>16 & 0xff)) + \
432               chr(int(x>>8 & 0xff)) + chr(int(x & 0xff)))
433     sock.send(header + frag)
434
435 def sendrecord(sock, record):
436     sendfrag(sock, 1, record)
437
438 def recvfrag(sock):
439     header = sock.recv(4)
440     if len(header) < 4:
441         raise EOFError
442     x = long(ord(header[0]))<<24 | ord(header[1])<<16 | \
443         ord(header[2])<<8 | ord(header[3])
444     last = ((x & 0x80000000) != 0)
445     n = int(x & 0x7fffffff)
446     frag = ''
447     while n > 0:
448         buf = sock.recv(n)
449         if not buf: raise EOFError
450         n = n - len(buf)
451         frag = frag + buf
452     return last, frag
453
454 def recvrecord(sock):
455     record = ''
456     last = 0
457     while not last:
458         last, frag = recvfrag(sock)
459         record = record + frag
460     return record
461
462
463 # Try to bind to a reserved port (must be root)
464
465 last_resv_port_tried = None
466 def bindresvport(sock, host):
467     global last_resv_port_tried
468     FIRST, LAST = 600, 1024 # Range of ports to try
469     if last_resv_port_tried == None:
470         last_resv_port_tried = FIRST + os.getpid() % (LAST-FIRST)
471     for i in range(last_resv_port_tried, LAST) + \
472               range(FIRST, last_resv_port_tried):
473         last_resv_port_tried = i
474         try:
475             sock.bind((host, i))
476             return last_resv_port_tried
477         except socket.error, (errno, msg):
478             if errno <> 114:
479                 raise socket.error(errno, msg)
480     raise RuntimeError("can't assign reserved port")
481
482
483 # Client using TCP to a specific port
484
485 class RawTCPClient(Client):
486
487     def makesocket(self):
488         self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
489
490     def do_call(self):
491         call = self.packer.get_buffer()
492         sendrecord(self.sock, call)
493         reply = recvrecord(self.sock)
494         u = self.unpacker
495         u.reset(reply)
496         xid, verf = u.unpack_replyheader()
497         if xid <> self.lastxid:
498             # Can't really happen since this is TCP...
499             raise XidMismatch(xid, self.lastxid)
500
501 # Client using UDP to a specific port
502
503 class RawUDPClient(Client):
504
505     def makesocket(self):
506         self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
507
508     def do_call(self):
509         call = self.packer.get_buffer()
510         self.sock.send(call)
511         try:
512             from select import select
513         except ImportError:
514             print 'WARNING: select not found, RPC may hang'
515             select = None
516         BUFSIZE = 8192 # Max UDP buffer size
517         timeout = 1
518         count = 5
519         while 1:
520             r, w, x = [self.sock], [], []
521             if select:
522                 r, w, x = select(r, w, x, timeout)
523             if self.sock not in r:
524                 count = count - 1
525                 if count < 0: raise TimeoutError() 
526                 if timeout < 25: timeout = timeout *2
527 ##                              print 'RESEND', timeout, count
528                 self.sock.send(call)
529                 continue
530             reply = self.sock.recv(BUFSIZE)
531             u = self.unpacker
532             u.reset(reply)
533             xid, verf = u.unpack_replyheader()
534             if xid <> self.lastxid:
535 ##                              print 'BAD xid'
536                 continue
537             break
538
539
540 # Client using UDP broadcast to a specific port
541
542 class RawBroadcastUDPClient(RawUDPClient):
543
544     def __init__(self, bcastaddr, prog, vers, port):
545         RawUDPClient.__init__(self, bcastaddr, prog, vers, port)
546         self.reply_handler = None
547         self.timeout = 30
548
549     def connsocket(self):
550         # Don't connect -- use sendto
551         self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
552
553     def set_reply_handler(self, reply_handler):
554         self.reply_handler = reply_handler
555
556     def set_timeout(self, timeout):
557         self.timeout = timeout # Use None for infinite timeout
558
559     def make_call(self, proc, args, pack_func, unpack_func):
560         if pack_func is None and args is not None:
561             raise TypeError("non-null args with null pack_func")
562         self.start_call(proc)
563         if pack_func:
564             pack_func(args)
565         call = self.packer.get_buffer()
566         self.sock.sendto(call, (self.host, self.port))
567         try:
568             from select import select
569         except ImportError:
570             print 'WARNING: select not found, broadcast will hang'
571             select = None
572         BUFSIZE = 8192 # Max UDP buffer size (for reply)
573         replies = []
574         if unpack_func is None:
575             def dummy(): pass
576             unpack_func = dummy
577         while 1:
578             r, w, x = [self.sock], [], []
579             if select:
580                 if self.timeout is None:
581                     r, w, x = select(r, w, x)
582                 else:
583                     r, w, x = select(r, w, x, self.timeout)
584             if self.sock not in r:
585                 break
586             reply, fromaddr = self.sock.recvfrom(BUFSIZE)
587             u = self.unpacker
588             u.reset(reply)
589             xid, verf = u.unpack_replyheader()
590             if xid <> self.lastxid:
591 ##                              print 'BAD xid'
592                 continue
593             reply = unpack_func()
594             try:
595                 self.unpacker.done()
596             except xdrlib.Error:
597                 raise RPCUnextractedData()
598             replies.append((reply, fromaddr))
599             if self.reply_handler:
600                 self.reply_handler(reply, fromaddr)
601         return replies
602
603
604 # Port mapper interface
605
606 # Program number, version and (fixed!) port number
607 PMAP_PROG = 100000
608 PMAP_VERS = 2
609 PMAP_PORT = 111
610
611 # Procedure numbers
612 PMAPPROC_NULL = 0                       # (void) -> void
613 PMAPPROC_SET = 1                        # (mapping) -> bool
614 PMAPPROC_UNSET = 2                      # (mapping) -> bool
615 PMAPPROC_GETPORT = 3                    # (mapping) -> unsigned int
616 PMAPPROC_DUMP = 4                       # (void) -> pmaplist
617 PMAPPROC_CALLIT = 5                     # (call_args) -> call_result
618
619 # A mapping is (prog, vers, prot, port) and prot is one of:
620
621 IPPROTO_TCP = 6
622 IPPROTO_UDP = 17
623
624 # A pmaplist is a variable-length list of mappings, as follows:
625 # either (1, mapping, pmaplist) or (0).
626
627 # A call_args is (prog, vers, proc, args) where args is opaque;
628 # a call_result is (port, res) where res is opaque.
629
630
631 class PortMapperPacker(Packer):
632
633     def pack_mapping(self, mapping):
634         prog, vers, prot, port = mapping
635         self.pack_uint(prog)
636         self.pack_uint(vers)
637         self.pack_uint(prot)
638         self.pack_uint(port)
639
640     def pack_pmaplist(self, list):
641         self.pack_list(list, self.pack_mapping)
642
643     def pack_call_args(self, ca):
644         prog, vers, proc, args = ca
645         self.pack_uint(prog)
646         self.pack_uint(vers)
647         self.pack_uint(proc)
648         self.pack_opaque(args)
649
650
651 class PortMapperUnpacker(Unpacker):
652
653     def unpack_mapping(self):
654         prog = self.unpack_uint()
655         vers = self.unpack_uint()
656         prot = self.unpack_uint()
657         port = self.unpack_uint()
658         return prog, vers, prot, port
659
660     def unpack_pmaplist(self):
661         return self.unpack_list(self.unpack_mapping)
662
663     def unpack_call_result(self):
664         port = self.unpack_uint()
665         res = self.unpack_opaque()
666         return port, res
667
668
669 class PartialPortMapperClient:
670     __pychecker__ = 'no-classattr'
671     def addpackers(self):
672         self.packer = PortMapperPacker()
673         self.unpacker = PortMapperUnpacker('')
674
675     def Set(self, mapping):
676         return self.make_call(PMAPPROC_SET, mapping, \
677                 self.packer.pack_mapping, \
678                 self.unpacker.unpack_uint)
679
680     def Unset(self, mapping):
681         return self.make_call(PMAPPROC_UNSET, mapping, \
682                 self.packer.pack_mapping, \
683                 self.unpacker.unpack_uint)
684
685     def Getport(self, mapping):
686         return self.make_call(PMAPPROC_GETPORT, mapping, \
687                 self.packer.pack_mapping, \
688                 self.unpacker.unpack_uint)
689
690     def Dump(self):
691         return self.make_call(PMAPPROC_DUMP, None, \
692                 None, \
693                 self.unpacker.unpack_pmaplist)
694
695     def Callit(self, ca):
696         return self.make_call(PMAPPROC_CALLIT, ca, \
697                 self.packer.pack_call_args, \
698                 self.unpacker.unpack_call_result)
699
700
701 class TCPPortMapperClient(PartialPortMapperClient, RawTCPClient):
702
703     def __init__(self, host):
704         RawTCPClient.__init__(self, \
705                 host, PMAP_PROG, PMAP_VERS, PMAP_PORT)
706
707
708 class UDPPortMapperClient(PartialPortMapperClient, RawUDPClient):
709
710     def __init__(self, host):
711         RawUDPClient.__init__(self, \
712                 host, PMAP_PROG, PMAP_VERS, PMAP_PORT)
713
714
715 class BroadcastUDPPortMapperClient(PartialPortMapperClient, \
716                                    RawBroadcastUDPClient):
717
718     def __init__(self, bcastaddr):
719         RawBroadcastUDPClient.__init__(self, \
720                 bcastaddr, PMAP_PROG, PMAP_VERS, PMAP_PORT)
721
722
723 # Generic clients that find their server through the Port mapper
724
725 class TCPClient(RawTCPClient):
726
727     def __init__(self, host, prog, vers):
728         pmap = TCPPortMapperClient(host)
729         port = pmap.Getport((prog, vers, IPPROTO_TCP, 0))
730         pmap.close()
731         if port == 0:
732             raise PortMapError("program not registered")
733         RawTCPClient.__init__(self, host, prog, vers, port)
734
735
736 class UDPClient(RawUDPClient):
737
738     def __init__(self, host, prog, vers):
739         pmap = UDPPortMapperClient(host)
740         port = pmap.Getport((prog, vers, IPPROTO_UDP, 0))
741         pmap.close()
742         if port == 0:
743             raise PortMapError("program not registered")
744         RawUDPClient.__init__(self, host, prog, vers, port)
745
746
747 class BroadcastUDPClient(Client):
748
749     def __init__(self, bcastaddr, prog, vers):
750         self.pmap = BroadcastUDPPortMapperClient(bcastaddr)
751         self.pmap.set_reply_handler(self.my_reply_handler)
752         self.prog = prog
753         self.vers = vers
754         self.user_reply_handler = None
755         self.addpackers()
756
757     def close(self):
758         self.pmap.close()
759
760     def set_reply_handler(self, reply_handler):
761         self.user_reply_handler = reply_handler
762
763     def set_timeout(self, timeout):
764         self.pmap.set_timeout(timeout)
765
766     def my_reply_handler(self, reply, fromaddr):
767         port, res = reply
768         self.unpacker.reset(res)
769         result = self.unpack_func()
770         try:
771             self.unpacker.done()
772         except xdrlib.Error:
773             raise RPCUnextractedData()
774         self.replies.append((result, fromaddr))
775         if self.user_reply_handler is not None:
776             self.user_reply_handler(result, fromaddr)
777
778     def make_call(self, proc, args, pack_func, unpack_func):
779         self.packer.reset()
780         if pack_func:
781             pack_func(args)
782         if unpack_func is None:
783             def dummy(): pass
784             self.unpack_func = dummy
785         else:
786             self.unpack_func = unpack_func
787         self.replies = []
788         packed_args = self.packer.get_buffer()
789         dummy_replies = self.pmap.Callit( \
790                 (self.prog, self.vers, proc, packed_args))
791         return self.replies
792
793
794 # Server classes
795
796 # These are not symmetric to the Client classes
797 # XXX No attempt is made to provide authorization hooks yet
798
799 class Server:
800
801     def __init__(self, host, prog, vers, port):
802         self.host = host # Should normally be '' for default interface
803         self.prog = prog
804         self.vers = vers
805         self.port = port # Should normally be 0 for random port
806         self.sock = None
807         self.prot = None
808         self.makesocket() # Assigns to self.sock and self.prot
809         self.bindsocket()
810         self.host, self.port = self.sock.getsockname()
811         self.addpackers()
812
813     def register(self):
814         mapping = self.prog, self.vers, self.prot, self.port
815         p = TCPPortMapperClient(self.host)
816         if not p.Set(mapping):
817             raise PortMapError("register failed")
818
819     def unregister(self):
820         mapping = self.prog, self.vers, self.prot, self.port
821         p = TCPPortMapperClient(self.host)
822         if not p.Unset(mapping):
823             raise PortMapError("unregister failed")
824
825     def handle(self, call):
826         # Don't use unpack_header but parse the header piecewise
827         # XXX I have no idea if I am using the right error responses!
828         self.unpacker.reset(call)
829         self.packer.reset()
830         xid = self.unpacker.unpack_uint()
831         self.packer.pack_uint(xid)
832         temp = self.unpacker.unpack_enum()
833         if temp <> CALL:
834             return None # Not worthy of a reply
835         self.packer.pack_uint(REPLY)
836         temp = self.unpacker.unpack_uint()
837         if temp <> RPCVERSION:
838             self.packer.pack_uint(MSG_DENIED)
839             self.packer.pack_uint(RPC_MISMATCH)
840             self.packer.pack_uint(RPCVERSION)
841             self.packer.pack_uint(RPCVERSION)
842             return self.packer.get_buffer()
843         self.packer.pack_uint(MSG_ACCEPTED)
844         self.packer.pack_auth((AUTH_NULL, make_auth_null()))
845         prog = self.unpacker.unpack_uint()
846         if prog <> self.prog:
847             self.packer.pack_uint(PROG_UNAVAIL)
848             return self.packer.get_buffer()
849         vers = self.unpacker.unpack_uint()
850         if vers <> self.vers:
851             self.packer.pack_uint(PROG_MISMATCH)
852             self.packer.pack_uint(self.vers)
853             self.packer.pack_uint(self.vers)
854             return self.packer.get_buffer()
855         proc = self.unpacker.unpack_uint()
856         methname = 'handle_' + `proc`
857         try:
858             meth = getattr(self, methname)
859         except AttributeError:
860             self.packer.pack_uint(PROC_UNAVAIL)
861             return self.packer.get_buffer()
862         self.recv_cred = self.unpacker.unpack_auth()
863         self.recv_verf = self.unpacker.unpack_auth()
864         try:
865             meth() # Unpack args, call turn_around(), pack reply
866         except (EOFError, RPCGarbageArgs):
867             # Too few or too many arguments
868             self.packer.reset()
869             self.packer.pack_uint(xid)
870             self.packer.pack_uint(REPLY)
871             self.packer.pack_uint(MSG_ACCEPTED)
872             self.packer.pack_auth((AUTH_NULL, make_auth_null()))
873             self.packer.pack_uint(GARBAGE_ARGS)
874         return self.packer.get_buffer()
875
876     def turn_around(self):
877         try:
878             self.unpacker.done()
879         except xdrlib.Error:
880             raise RPCUnextractedData()
881         
882         self.packer.pack_uint(SUCCESS)
883
884     def handle_0(self): # Handle NULL message
885         self.turn_around()
886
887     def makesocket(self):
888         # This MUST be overridden
889         raise RuntimeError("makesocket not defined")
890
891     def bindsocket(self):
892         # Override this to bind to a different port (e.g. reserved)
893         self.sock.bind((self.host, self.port))
894
895     def addpackers(self):
896         # Override this to use derived classes from Packer/Unpacker
897         self.packer = Packer()
898         self.unpacker = Unpacker('')
899
900
901 class TCPServer(Server):
902
903     def makesocket(self):
904         self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
905         self.prot = IPPROTO_TCP
906
907     def loop(self):
908         self.sock.listen(0)
909         while 1:
910             self.session(self.sock.accept())
911
912     def session(self, connection):
913         sock, (host, port) = connection
914         while 1:
915             try:
916                 call = recvrecord(sock)
917             except EOFError:
918                 break
919             except socket.error, msg:
920                 print 'socket error:', msg
921                 break
922             reply = self.handle(call)
923             if reply is not None:
924                 sendrecord(sock, reply)
925
926     def forkingloop(self):
927         # Like loop but uses forksession()
928         self.sock.listen(0)
929         while 1:
930             self.forksession(self.sock.accept())
931
932     def forksession(self, connection):
933         # Like session but forks off a subprocess
934         # Wait for deceased children
935         try:
936             while 1:
937                 pid, sts = os.waitpid(0, 1)
938         except os.error:
939             pass
940         pid = None
941         try:
942             pid = os.fork()
943             if pid: # Parent
944                 connection[0].close()
945                 return
946             # Child
947             self.session(connection)
948         finally:
949             # Make sure we don't fall through in the parent
950             if pid == 0:
951                 os._exit(0)
952
953
954 class UDPServer(Server):
955
956     def makesocket(self):
957         self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
958         self.prot = IPPROTO_UDP
959
960     def loop(self):
961         while 1:
962             self.session()
963
964     def session(self):
965         call, host_port = self.sock.recvfrom(8192)
966         self.sender_port = host_port
967         reply = self.handle(call)
968         if reply <> None:
969             self.sock.sendto(reply, host_port)
970
971
972 # Simple test program -- dump local portmapper status
973
974 def test():
975     pmap = UDPPortMapperClient('')
976     list = pmap.Dump()
977     list.sort()
978     for prog, vers, prot, port in list:
979         print prog, vers,
980         if prot == IPPROTO_TCP: print 'tcp',
981         elif prot == IPPROTO_UDP: print 'udp',
982         else: print prot,
983         print port
984
985
986 # Test program for broadcast operation -- dump everybody's portmapper status
987
988 def testbcast():
989     import sys
990     if sys.argv[1:]:
991         bcastaddr = sys.argv[1]
992     else:
993         bcastaddr = '<broadcast>'
994     def rh(reply, fromaddr):
995         host, port = fromaddr
996         print host + '\t' + `reply`
997     pmap = BroadcastUDPPortMapperClient(bcastaddr)
998     pmap.set_reply_handler(rh)
999     pmap.set_timeout(5)
1000     unused_replies = pmap.Getport((100002, 1, IPPROTO_UDP, 0))
1001
1002
1003 # Test program for server, with corresponding client
1004 # On machine A: python -c 'import rpc; rpc.testsvr()'
1005 # On machine B: python -c 'import rpc; rpc.testclt()' A
1006 # (A may be == B)
1007
1008 def testsvr():
1009     # Simple test class -- proc 1 doubles its string argument as reply
1010     class S(UDPServer):
1011         def handle_1(self):
1012             arg = self.unpacker.unpack_string()
1013             self.turn_around()
1014             print 'RPC function 1 called, arg', `arg`
1015             self.packer.pack_string(arg + arg)
1016     #
1017     s = S('', 0x20000000, 1, 0)
1018     try:
1019         s.unregister()
1020     except PortMapError, e:
1021         print 'RuntimeError:', e.args, '(ignored)'
1022     s.register()
1023     print 'Service started...'
1024     try:
1025         s.loop()
1026     finally:
1027         s.unregister()
1028         print 'Service interrupted.'
1029
1030
1031 def testclt():
1032     import sys
1033     if sys.argv[1:]: host = sys.argv[1]
1034     else: host = ''
1035     # Client for above server
1036     class C(UDPClient):
1037         def call_1(self, arg):
1038             return self.make_call(1, arg, \
1039                     self.packer.pack_string, \
1040                     self.unpacker.unpack_string)
1041     c = C(host, 0x20000000, 1)
1042     print 'making call...'
1043     reply = c.call_1('hello, world, ')
1044     print 'call returned', `reply`
1045
1046
1047 # Local variables:
1048 # py-indent-offset: 4
1049 # tab-width: 8
1050 # End: