如何在线程里跑协程
Mar 15, 2020
1 minute read

当使用不支持 async/await 协程的 web 框架时,想去使用协程来加快与其他服务的链接, 比如链接 Redis,发起大量网络请求等。所以我们需要在线程里跑 asyncio event loop。

创建用来跑 asyncio event loop 线程

import asyncio
import threading

class AsyncioEventLoopThread(threading.Thread):
    def __init__(self, *args, loop=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loop = loop or asyncio.new_event_loop()
        self.running = False

    def run(self):
        self.running = True
        self.loop.run_forever()

    def run_coro(self, coro):
        return asyncio.run_coroutine_threadsafe(coro, loop=self.loop).result()

    def stop(self):
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.join()
        self.running = False

因为 asyncio.run_coroutine_threadsafe 的返回值为 concurrent.futures.Future,可以直接执行其 result 方法,函数会等待直到协程执行完毕并返回结果。

下面的例子展示了如何使用 AsyncioEventLoopThread

async def hello_world():
    print("hello world")

async def make_request():
    await asyncio.sleep(1)

thr = AsyncioEventLoopThread()
thr.start()
try:
    thr.run_coro(hello_world())
    thr.run_coro(make_request())
finally:
    thr.stop()

注意,不要用两个不同的 event loop 运行同样的协程

在不同的协程中分享对象

为了达到这个目的,应该继承 AsyncioEventLoopThread 或直接修改它,来持用需要在协程间被分享的对象。

使用 contextvars 去在协程间分享对象时,在运行协程前必须要把需要分享的对象保存在 contextvars.Context 里面。

import contextvars
import aiohttp

var_session = contextvars.ContextVar('session')

class FetcherThread(AsyncioEventLoopThread):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._session = None
        self._event_session_created = threading.Event()

    def run_coro(self, coro):
        self._event_session_created.wait()
        var_session.set(self.session)
        return super().run_coro(coro)

    async def _create_session(self):
        self.session = aiohttp.ClientSession()

    async def _close_session(self):
        await self.session.close()

    def run(self):
        fut = asyncio.run_coroutine_threadsafe(self._create_session(), loop=self.loop)
        fut.add_done_callback(lambda _: self._event_session_created.set())
        super().run()

    def stop(self):
        self.run_coro(self._close_session())
        super().stop()

下面的例子展示了如何爬取 github.com 的网页源码。

async def make_request():
    session = var_session.get()
    async with session.get("https://github.com") as resp:
        resp.raise_for_status()
        return await resp.text()


thr = FetcherThread()
thr.start()
try:
    text = thr.run_coro(make_request())
    print(text)
finally:
    thr.stop()

参考




comments powered by Disqus