在 Python 中获取一个未绑定的端口

本文翻译自 Python Getting A Free Port Number : A Multiprocess-safe Recipe

写本文的目的是介绍如何在 Python 中实现一个函数:get_free_port 返回一个未使用的端口号,并且这个函数支持在多线程和多进程环境中使用。

这也就意味着:无论在何时、何处调用 get_free_port,它返回的端口号都应当能够被成功绑定。

事实上,这个目标在逻辑上并不能实现,因为这个函数并不知道它返回的端口号是否已经被使用了,所以最佳的实践方式就是记住所有返回的端口号,每次调用都检查端口号是否已经输出过。

如果我们再增加一个函数用来释放端口就显得更符合逻辑了。

在 Python 中,对一个套接字调用 sock.bind(('', 0)) 时,操作系统会自动为它分配一个空闲的端口号,所以我们可以借助这一特性来实现:

1
2
3
4
5
6
7
8
import socket 

def get_free_port():
    """Return a port number that was free at the moment of the call.

    A temporary socket is bound to port 0 so that the operating system picks
    an unused port; the socket is closed again before returning.

    Note that the port is only guaranteed to be free while the temporary
    socket is open. Once this function returns, any other thread or process
    may bind the same port, so this helper alone is not race-free.

    Returns:
        The port number (int) chosen by the operating system.
    """
    sock = socket.socket()
    try:
        sock.bind(('', 0))
        # getsockname() yields the (host, port) pair the socket is bound to;
        # sockets have no getnameinfo() method.
        _, port = sock.getsockname()
    finally:
        # Always give the descriptor back, even if bind() fails.
        sock.close()
    return port

这个函数看上去能工作,但是还远远不够。采用这个方法只能让端口在非常短的时间内不被绑定,难以满足在竞态场景中使用。

由此不难看出,获得一个未被绑定的端口号不难,难的是如何把这个端口号安全正确地返回给调用者。鉴于此,我们需要找到一种机制可以首先保证这个未绑定的端口号不被肆意绑定:

1
get a free port -> look at dictionary (and lock file) -> bind a free port -> write a dictionary (and lock file) -> release port -> return the port

使用 lock file 能够保证:即使某个端口号当前尚未被绑定,其他进程在拿到对应的锁之前也不会去占用它。这里使用了 fasteners 库提供的进程间文件锁(InterProcessLock)。全部代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# freeport.py
import fasteners
import threading

class BindFreePort(object):
    """Bind a socket to a randomly chosen free port in ``[start, stop]``.

    The constructor keeps picking random ports in the inclusive range until
    one can be bound, so it blocks for as long as every port in the range is
    taken. The bound socket is kept open in ``self.sock`` until
    :meth:`release` is called.

    Attributes:
        port: The port number that was bound (``None`` until binding succeeds).
        sock: The socket holding the port.
    """

    def __init__(self, start, stop):
        """Bind ``self.sock`` to a random port between ``start`` and ``stop``.

        Args:
            start: Lowest candidate port (inclusive).
            stop: Highest candidate port (inclusive).

        Raises:
            ValueError: If ``start`` is greater than ``stop``.
            OverflowError: If a candidate port is outside 0-65535.
        """
        self.port = None

        import random
        import socket

        self.sock = socket.socket()

        try:
            while True:
                port = random.randint(start, stop)
                try:
                    self.sock.bind(('', port))
                except socket.error:
                    # Port is busy or not permitted: try another one.
                    # Only socket-level failures are retried; anything else
                    # (e.g. an out-of-range port) would never succeed and
                    # previously caused an endless loop.
                    continue
                self.port = port
                break
        except BaseException:
            # Do not leak the descriptor when construction fails.
            self.sock.close()
            raise

    def release(self):
        """Close the socket, making the bound port available again."""
        assert self.port is not None
        self.sock.close()


class FreePort(object):
    """Reserve a port number guarded by a per-port lock file.

    Construction loops until it finds a port in ``[start, stop]`` that can be
    bound, is not already recorded in ``used_ports``, and whose lock file
    can be acquired without blocking. The reservation lasts until
    :meth:`release` is called.

    Attributes:
        lock: The acquired ``InterProcessLock`` for the reserved port.
        bind: Initialised to ``None`` and not assigned again in this class.
        port: The reserved port number.
    """

    # Ports reserved through this class in the current process and not yet
    # released. Defined on the class, so it is shared by all instances.
    used_ports = set()

    def __init__(self, start=4000, stop=6000):
        """Reserve a port between ``start`` and ``stop`` (both inclusive).

        Args:
            start: Lowest candidate port.
            stop: Highest candidate port.
        """
        self.lock = None
        self.bind = None
        self.port = None

        from fasteners.process_lock import InterProcessLock
        import time

        acquired = False
        while not acquired:
            candidate = BindFreePort(start, stop)

            # Already handed out by this process: drop the probe socket and
            # pick again straight away.
            if candidate.port in self.used_ports:
                candidate.release()
                continue

            # Since we cannot be certain the user will bind the port
            # 'immediately' (actually it is not possible using this flow),
            # we must ensure that the port will not be reacquired even if it
            # is not bound to anything.
            port_lock = InterProcessLock(path='/tmp/socialdna/port_{}_lock'.format(candidate.port))
            acquired = port_lock.acquire(blocking=False)

            if acquired:
                self.lock = port_lock
                self.port = candidate.port
                # The port is recorded before the probe socket is closed --
                # presumably so that no other thread of this process can
                # bind it in between; keep this order.
                self.used_ports.add(candidate.port)
                candidate.release()
            else:
                # Lock not obtained: free the port and back off briefly
                # before trying another candidate.
                candidate.release()
                time.sleep(0.01)

    def release(self):
        """Forget the reserved port and release its lock file."""
        assert self.lock is not None
        assert self.port is not None
        self.used_ports.remove(self.port)
        self.lock.release()

测试如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#freeport_test.py
import unittest

from freeport import FreePort

def get_and_bind_freeport(*args):
    """Reserve a port in 4000-4009, hold it for one second, return its number.

    Defined at module level and used as the worker function for the
    multiprocess pool test. Positional arguments are accepted and ignored so
    the function can be passed to ``map``.
    """
    import time

    reservation = FreePort(start=4000, stop=4009)
    # Keep the reservation alive for a while so that concurrently running
    # workers overlap with it.
    time.sleep(1)
    return reservation.port

class FreePortClassTest(unittest.TestCase):
    """Tests for ``FreePort`` in single-, multi-thread and multi-process use.

    Every reservation, pool and child process is cleaned up in a ``finally``
    block. A failed assertion must not leave a port recorded in
    ``FreePort.used_ports``, otherwise later tests asking for the same port
    would loop forever.
    """

    def test_one_port(self):
        """A one-port range yields exactly that port."""
        freeport = FreePort(start=4000, stop=4000)
        try:
            self.assertEqual(freeport.port, 4000)
        finally:
            freeport.release()

    def test_many_ports(self):
        """A released port can be reserved again."""
        freeport = FreePort(start=4000, stop=4000)
        try:
            self.assertEqual(freeport.port, 4000)
        finally:
            freeport.release()

        freeport = FreePort(start=4000, stop=4000)
        try:
            self.assertEqual(freeport.port, 4000)
        finally:
            freeport.release()

    def test_many_ports_conflict(self):
        """A child process cannot obtain a port this process still holds."""
        def get_port():
            freeport = FreePort(start=4000, stop=4000)
            return freeport.port

        def run():
            self.assertEqual(get_port(), 4000)

        freeport = FreePort(start=4000, stop=4000)
        try:
            self.assertEqual(freeport.port, 4000)

            from multiprocessing import Process
            # NOTE: ``run`` is a nested function, so this relies on a start
            # method that does not need to pickle the target (e.g. fork).
            p = Process(target=run)
            p.start()
            try:
                p.join(0.1)
                self.assertTrue(p.is_alive(), 'the process should find it hard to acquire a free port')
            finally:
                # Always stop the child, even when the assertion fails.
                p.terminate()
                p.join()
        finally:
            freeport.release()

    def test_multithread_race_condition(self):
        """Concurrent threads never receive the same port."""
        from multiprocessing.pool import ThreadPool
        jobs = 100

        def get_and_bind_freeport(*args):
            freeport = FreePort(start=4000, stop=4000 + jobs - 1)
            import time
            time.sleep(1)
            freeport.release()  # needed because thread will not turn back the file descriptor
            return freeport.port

        p = ThreadPool(jobs)
        try:
            ports = p.map(get_and_bind_freeport, range(jobs))
        finally:
            # Shut the worker threads down instead of leaking them.
            p.close()
            p.join()
        self.assertEqual(len(ports), len(set(ports)))

    def test_multiprocess_race_condition(self):
        """Concurrent processes never receive the same port."""
        from multiprocessing.pool import Pool
        p = Pool(10)
        try:
            ports = p.map(get_and_bind_freeport, range(10))
        finally:
            # Shut the worker processes down instead of leaking them.
            p.close()
            p.join()
        self.assertEqual(len(ports), len(set(ports)))
三月沙 wechat
扫描关注 wecatch 的公众号