Added initial unfs3 sources for version 0.9.22+dfsg-1maemo2
[unfs3] / unfs3 / contrib / rpcproxy / rpcproxy
1 #!/usr/bin/env python
2 # -*-mode: python; coding: utf-8 -*-
3
4 # TODO:
5 # Support for limiting data sizes, max number of connections from the same IP etc
6
import errno
import select
import socket
import struct
import sys
import time
12
# Connection states, used by both client and server connections.
#   Client cycle: STATE_READING -> STATE_WAITING -> STATE_WRITING
#   Server cycle: STATE_WAITING -> STATE_WRITING -> STATE_READING
# STATE_READING: reading a record (a request from a client, a reply from
#                the server).
# STATE_WAITING: a client waits for the server's reply; the server has no
#                request to send.
# STATE_WRITING: writing a reply to a client or a request to the server.
# STATE_EOF:     EOF was detected while reading.
STATE_READING, STATE_WAITING, STATE_WRITING, STATE_EOF = 0, 2, 3, 4

# Record-marking constants
FRAG_HEADER_LEN = 4            # bytes in a fragment header (record mark)
FRAG_MAX_SIZE = (1 << 31) - 1  # largest length the 31-bit field can hold
FRAG_SIZE = FRAG_MAX_SIZE      # size of newly created fragments
25
26
class ProxyEngine:
    """Registry of everything the main select() loop polls.

    Holds the listening Proxy objects and all client/server connections.
    """

    def __init__(self):
        # Listening Proxy objects, one per command-line mapping.
        self.proxies = []
        # Client and server connections polled by the main loop.
        self.connections = []


    def add_proxy(self, bind_address, port, host, hostport):
        """Create a Proxy and register it along with its server connection."""
        new_proxy = Proxy(self, bind_address, port, host, hostport)
        self.proxies.append(new_proxy)
        self.add_connection(new_proxy.srv)


    def add_connection(self, conn):
        """Register conn so that the main loop polls it."""
        self.connections.append(conn)
42
43
class Proxy:
    """A listening socket for one [bind_address:]port -> host:hostport mapping.

    Every accepted client is wrapped in a ClientConnection that forwards its
    requests through this proxy's single ServerConnection (self.srv).
    """

    # accept() errors that only mean "no connection to take right now".
    _ACCEPT_RETRY_ERRNOS = (errno.EAGAIN, errno.EWOULDBLOCK,
                            errno.ECONNABORTED, errno.EINTR)

    def __init__(self, pe, bind_address, port, host, hostport):
        self.pe = pe
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind((bind_address, port))
        self.sock.listen(1)
        # select() can report the listening socket readable for a connection
        # that is aborted before accept() runs. A blocking accept() would then
        # stall the whole single-threaded select() loop, so accept without
        # blocking and treat "nothing pending" as a no-op.
        self.sock.setblocking(0)
        self.srv = ServerConnection(host, hostport)


    def fileno(self):
        """Return the listening socket's fileno (used by select())"""
        return self.sock.fileno()


    def handle_read(self):
        """Accept a pending connection and register it with the engine.

        Nothing is returned; the new ClientConnection is handed to the
        ProxyEngine. Spurious wake-ups (no connection actually pending) are
        ignored, and any other accept() error is re-raised.
        """
        try:
            sock, addr = self.sock.accept()
        except socket.error:
            e = sys.exc_info()[1]
            if e.args and e.args[0] in self._ACCEPT_RETRY_ERRNOS:
                return
            raise
        # Whether an accepted socket inherits O_NONBLOCK from the listener is
        # platform dependent. Connection code does not handle EAGAIN from
        # send()/recv(), so make the client socket blocking explicitly.
        sock.setblocking(1)
        self.pe.add_connection(ClientConnection(sock, addr, self.srv))
63
64
class ServerCall:
    """One request queued on a ServerConnection.

    data is the request payload (without record marks). callback is called
    with the reply payload once the server has answered.
    """

    def __init__(self, data, callback):
        self.data, self.callback = data, callback
69
70
class RPCConnection:
    """Base class for a TCP connection carrying ONC RPC records.

    On the wire a record is a sequence of fragments. Each fragment is
    preceded by a 4-byte big-endian record mark (RM): the low 31 bits give
    the fragment length and the high bit flags the record's last fragment.
    self.record buffers the raw bytes (RMs included) of the record currently
    being read or written. Subclasses must set self.state to one of the
    STATE_* constants before the connection is polled.
    """

    # Upper bound for a single recv(). A record mark may announce up to
    # 2**31-1 bytes, and CPython's recv(n) allocates an n-byte buffer up
    # front, so the "missing bytes" count is never passed through unchecked.
    RECV_CHUNK = 65536

    def __init__(self):
        self.record = "" # Current record, as stream with RMs
        self.sndbuf = None # SO_SNDBUF of self.sock, set by set_sock()
        self.sock = None


    def set_sock(self, sock):
        """Set socket to use and remember its SO_SNDBUF size"""
        self.sock = sock
        self.sndbuf = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)


    def eof_event(self):
        """recv/send discovered that the connection was closed.

        Marks the connection STATE_EOF and closes the socket. May be
        overridden.
        """
        self.state = STATE_EOF
        self.sock.close()


    def assert_sock(self):
        """Make sure a socket is available. May be overridden."""
        assert self.sock is not None


    # States are ints: compare with ==, since identity of equal ints is a
    # CPython caching detail, not a language guarantee.
    def readable(self):
        """Returns true if connection wants to read"""
        return self.state == STATE_READING


    def writable(self):
        """Returns true if connection wants to write"""
        return self.state == STATE_WRITING


    def eof(self):
        """Returns true if EOF has been detected"""
        return self.state == STATE_EOF


    def fileno(self):
        """Return the sockets fileno"""
        self.assert_sock()
        return self.sock.fileno()


    def write_record(self):
        """Write (part of) the pending record.

        Returns true when everything has been written. If send() raises
        socket.error (e.g. EPIPE or ECONNRESET after the peer went away),
        eof_event() is called and false is returned.
        """
        self.assert_sock()
        # Send at most SO_SNDBUF bytes per call to limit time spent in send().
        # NOTE(review): sockets are presumably blocking here (they are never
        # set non-blocking), so send() may still block when the send buffer
        # is partly full -- confirm this is acceptable.
        try:
            wrote = self.sock.send(self.record[:self.sndbuf])
        except socket.error:
            self.eof_event()
            return False
        self.record = self.record[wrote:]
        return len(self.record) == 0


    def frag_length(self, head):
        """Return the length of a fragment, including header"""
        assert len(head) == FRAG_HEADER_LEN
        x = struct.unpack('>L', head)[0]
        return int(x & 0x7fffffff) + FRAG_HEADER_LEN


    def frag_last(self, head):
        """Return true if last flag is set"""
        assert len(head) == FRAG_HEADER_LEN
        x = struct.unpack('>L', head)[0]
        return (x & 0x80000000) != 0


    def rm_stream(self, stream, frag_size=None):
        """Record-mark a data stream.

        Split stream into fragments of at most frag_size bytes (default
        FRAG_SIZE). Prefix each fragment with its 4-byte big-endian length,
        and set the high bit in the final fragment's header. An empty stream
        yields a single empty, last fragment.

        Raises ValueError if frag_size is not in 1..FRAG_MAX_SIZE.
        """
        if frag_size is None:
            frag_size = FRAG_SIZE
        if not 0 < frag_size <= FRAG_MAX_SIZE:
            raise ValueError("fragment size out of range: %r" % (frag_size,))

        pieces = []
        pos = 0
        while 1:
            chunk = stream[pos:pos + frag_size]
            pos += len(chunk)
            last = pos >= len(stream)
            header = len(chunk)
            if last:
                header |= 0x80000000
            pieces.append(struct.pack('>L', header) + chunk)
            if last:
                return "".join(pieces)


    def parsed_record(self):
        """Return tuple (data, missing_bytes) for self.record.

        data is the concatenated payload (RMs stripped) of the fragments
        received so far. missing_bytes is how many more bytes are needed to
        finish the current record mark or fragment, or 0 once the fragment
        flagged as last is complete.
        """
        record = self.record
        payload = []
        pos = 0
        while 1:
            available = len(record) - pos
            if available < FRAG_HEADER_LEN:
                # The next record mark is not (fully) here yet.
                return ("".join(payload), FRAG_HEADER_LEN - available)

            head = record[pos:pos + FRAG_HEADER_LEN]
            frag_len = self.frag_length(head) # includes the header
            # Only this fragment's payload. The slice stops at the end of the
            # buffered data when the fragment is still incomplete.
            payload.append(record[pos + FRAG_HEADER_LEN:pos + frag_len])

            if available < frag_len:
                # Incomplete fragment
                return ("".join(payload), frag_len - available)
            if self.frag_last(head):
                # Complete last fragment: the record is done
                return ("".join(payload), 0)
            # Complete non-last fragment: move on to the next record mark.
            # frag_len >= FRAG_HEADER_LEN, so pos always advances.
            pos += frag_len


    def read_record(self):
        """Read more of the current RPC record. Returns true if complete.

        Only the bytes still missing (capped at RECV_CHUNK) are requested,
        so data belonging to the next record is never consumed. On EOF or a
        socket error, eof_event() is called and 0 is returned.
        """
        self.assert_sock()
        assert self.state == STATE_READING
        bytes_to_read = self.parsed_record()[1]
        if bytes_to_read == 0:
            return 1

        try:
            data = self.sock.recv(min(bytes_to_read, self.RECV_CHUNK))
        except socket.error:
            # e.g. ECONNRESET: handle like an orderly close instead of letting
            # the exception escape the select() loop.
            self.eof_event()
            return 0

        if not data:
            self.eof_event()
            return 0

        self.record += data
        return self.parsed_record()[1] == 0
206
207
class ServerConnection(RPCConnection):
    """Connection to the real RPC server, shared by all clients of a Proxy.

    Requests are queued as ServerCall objects in self.calls and forwarded
    one at a time. The request record is written, the reply record is read
    and passed to that call's callback, and only then is the next queued
    request sent. The socket is opened lazily by assert_sock() and re-opened
    by eof_event() when the server drops the connection.
    """

    def __init__(self, host, port):
        RPCConnection.__init__(self)
        self.host = host
        self.port = port
        self.calls = [] # A list of ServerCalls, oldest first
        self.state = STATE_WAITING
        self.current_cb = None # Callback of the call being written/read


    def eof_event(self):
        """Overridden eof_event, which re-connects.

        The call in flight (if any) is completed by invoking its callback
        with an empty string. Queued calls are kept and are sent on the new
        connection. Reconnecting blocks until the server accepts (see
        assert_sock).

        NOTE(review): an empty string is not a valid RPC reply; consider
        re-queueing the call or closing the client connection instead.
        """
        print >>sys.stderr, "Lost connection to server, trying to reconnect."
        # Discard the current call (guarded in case none is in flight).
        if self.current_cb is not None:
            self.current_cb("")
            self.current_cb = None
        if self.calls:
            self.state = STATE_WRITING
        else:
            self.state = STATE_WAITING

        # Re-create socket
        self.sock.close()
        self.sock = None
        self.assert_sock()


    def assert_sock(self):
        """Overridden assert_sock, which connects dynamically.

        If no socket is open, connect to (host, port), retrying every 5
        seconds until it succeeds. This blocks the whole proxy meanwhile.
        """
        while self.sock is None:
            # POSIX leaves a socket's state unspecified after a failed
            # connect(), so every attempt uses (and cleans up) a fresh socket.
            srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                srv_sock.connect((self.host, self.port))
            except socket.error:
                e = sys.exc_info()[1]
                srv_sock.close()
                # socket.error normally carries (errno, message); fall back
                # to the exception itself if it does not.
                if e.args:
                    reason = e.args[-1]
                else:
                    reason = e
                print >>sys.stderr, "Connection to %s:%d failed: %s" % (self.host, self.port, reason)
                time.sleep(5)
            else:
                print >>sys.stderr, "Connected to %s:%d" % (self.host, self.port)
                self.set_sock(srv_sock)


    def call(self, servercall):
        """Put another call on the call queue. The call argument is a
        stream, without RMs. The callback will be called with result"""
        self.calls.append(servercall)
        if self.state == STATE_WAITING:
            self.state = STATE_WRITING


    def handle_read(self):
        """Called when socket is ready for read.

        Once the reply record is complete, hand its payload to the current
        call's callback and move on to the next queued call (if any).
        """
        if self.read_record():
            self.current_cb(self.parsed_record()[0])
            self.current_cb = None
            if self.calls:
                self.state = STATE_WRITING
            else:
                self.state = STATE_WAITING


    def handle_write(self):
        """Called when socket is ready for write.

        Starts the oldest queued call if none is in flight, then writes its
        record-marked request. When fully written, switch to reading the
        reply.
        """
        assert self.state == STATE_WRITING
        if self.current_cb is None:
            # Start working on another request
            servercall = self.calls.pop(0)
            self.record = self.rm_stream(servercall.data)
            self.current_cb = servercall.callback

        assert self.current_cb
        if self.write_record():
            self.state = STATE_READING
            self.record = ""
282
283
class ClientConnection(RPCConnection):
    """An accepted client socket.

    Cycles through STATE_READING (receive one request record), STATE_WAITING
    (request queued on the shared server connection) and STATE_WRITING (send
    the reply record back), then starts reading again.
    """

    def __init__(self, sock, addr, srv):
        RPCConnection.__init__(self)
        self.addr = addr # Peer address as returned by accept()
        self.srv = srv # Server connection requests are forwarded to
        self.set_sock(sock)
        self.state = STATE_READING


    def handle_read(self):
        """Socket is readable: once a full request is in, queue it on the server."""
        if not self.read_record():
            return
        self.state = STATE_WAITING
        request = self.parsed_record()[0]
        self.srv.call(ServerCall(request, self.got_response))


    def handle_write(self):
        """Socket is writable: push the reply; go back to reading when done."""
        assert self.state == STATE_WRITING
        if not self.write_record():
            return
        self.state = STATE_READING
        self.record = ""


    def got_response(self, data):
        """Server callback: record-mark the reply payload and start sending it."""
        self.state = STATE_WRITING
        self.record = self.rm_stream(data)
313
314
def usage():
    """Exit with the command-line synopsis as the error message (status 1)."""
    synopsis = "Usage: %s [bind_address:]port:host:hostport ..."
    sys.exit(synopsis % sys.argv[0])
317
318
def parse_arg(arg):
    """Parse a command argument, specifying hosts and ports.

    arg has the form [bind_address:]port:host:hostport. If the bind address
    is omitted, 127.0.0.1 is used. Returns the tuple
    (bind_address, port, host, hostport) with both ports as ints.
    A malformed argument, or a port that is not an integer in 0..65535,
    exits via usage() instead of raising a traceback.
    """
    fields = arg.split(":")
    if len(fields) == 3:
        fields.insert(0, "127.0.0.1")

    if len(fields) != 4:
        usage()

    bind_address, port, host, hostport = fields
    try:
        port = int(port)
        hostport = int(hostport)
    except ValueError:
        usage()
    # socket.bind()/connect() would otherwise fail later with OverflowError.
    if not (0 <= port <= 65535 and 0 <= hostport <= 65535):
        usage()
    return bind_address, port, host, hostport
333
334
def main():
    """Set up one proxy per command-line argument and run the select() loop.

    Runs forever; exits early via usage() when no mapping is given.
    """
    if len(sys.argv) < 2:
        usage()

    pe = ProxyEngine()

    #
    # Determine hosts and ports
    #
    for arg in sys.argv[1:]:
        pe.add_proxy(*parse_arg(arg))

    #
    # Select loop
    #
    while 1:
        # Listening sockets are always polled for new clients; connections
        # are polled according to their current state.
        read_set = list(pe.proxies)
        write_set = []
        for conn in pe.connections:
            if conn.readable():
                read_set.append(conn)
            if conn.writable():
                write_set.append(conn)

        rlist, wlist, _ = select.select(read_set, write_set, [])

        for obj in rlist:
            obj.handle_read()

        for obj in wlist:
            obj.handle_write()

        # Drop closed connections. Rebuild the list in place rather than
        # calling remove() while iterating, which skips the element after
        # each removed one.
        pe.connections[:] = [conn for conn in pe.connections if not conn.eof()]
372
373
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the proxy: exit with status 0
        # instead of printing a traceback.
        raise SystemExit(0)