diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/requests.py | 150 |
1 files changed, 109 insertions, 41 deletions
diff --git a/g4f/requests.py b/g4f/requests.py index 736442e3..78acb9de 100644 --- a/g4f/requests.py +++ b/g4f/requests.py @@ -1,20 +1,24 @@ from __future__ import annotations -import json, sys +import warnings, json, asyncio + from functools import partialmethod +from asyncio import Future, Queue +from typing import AsyncGenerator -from aiohttp import StreamReader -from aiohttp.base_protocol import BaseProtocol +from curl_cffi.requests import AsyncSession, Response -from curl_cffi.requests import AsyncSession as BaseSession -from curl_cffi.requests import Response +import curl_cffi +is_newer_0_5_8 = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl") +is_newer_0_5_9 = hasattr(curl_cffi.AsyncCurl, "remove_handle") +is_newer_0_5_10 = hasattr(AsyncSession, "release_curl") class StreamResponse: - def __init__(self, inner: Response, content: StreamReader, request): + def __init__(self, inner: Response, queue: Queue): self.inner = inner - self.content = content - self.request = request + self.queue = queue + self.request = inner.request self.status_code = inner.status_code self.reason = inner.reason self.ok = inner.ok @@ -22,7 +26,7 @@ class StreamResponse: self.cookies = inner.cookies async def text(self) -> str: - content = await self.content.read() + content = await self.read() return content.decode() def raise_for_status(self): @@ -30,56 +34,120 @@ class StreamResponse: raise RuntimeError(f"HTTP Error {self.status_code}: {self.reason}") async def json(self, **kwargs): - return json.loads(await self.content.read(), **kwargs) + return json.loads(await self.read(), **kwargs) + + async def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None) -> AsyncGenerator[bytes]: + """ + Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/ + which is under the License: Apache 2.0 + """ + pending = None + + async for chunk in self.iter_content( + chunk_size=chunk_size, decode_unicode=decode_unicode + ): + if pending is not None: + chunk = pending + chunk + if delimiter: + lines = chunk.split(delimiter) + else: + lines = chunk.splitlines() + if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]: + pending = lines.pop() + else: + pending = None + + for line in lines: + yield line + + if pending is not None: + yield pending + + async def iter_content(self, chunk_size=None, decode_unicode=False) -> As: + if chunk_size: + warnings.warn("chunk_size is ignored, there is no way to tell curl that.") + if decode_unicode: + raise NotImplementedError() + while True: + chunk = await self.queue.get() + if chunk is None: + return + yield chunk + + async def read(self) -> bytes: + return b"".join([chunk async for chunk in self.iter_content()]) class StreamRequest: def __init__(self, session: AsyncSession, method: str, url: str, **kwargs): self.session = session - self.loop = session.loop - self.content = StreamReader( - BaseProtocol(session.loop), - sys.maxsize, - loop=session.loop - ) + self.loop = session.loop if session.loop else asyncio.get_running_loop() + self.queue = Queue() self.method = method self.url = url self.options = kwargs + self.handle = None - def on_content(self, data): + def _on_content(self, data): if not self.enter.done(): self.enter.set_result(None) - self.content.feed_data(data) + self.queue.put_nowait(data) - def on_done(self, task): - self.content.feed_eof() - self.curl.clean_after_perform() - self.curl.reset() - self.session.push_curl(self.curl) + def _on_done(self, task: Future): + if not self.enter.done(): + self.enter.set_result(None) + self.queue.put_nowait(None) - async def __aenter__(self) -> StreamResponse: + self.loop.call_soon(self.session.release_curl, self.curl) + + async def fetch(self) -> StreamResponse: + if self.handle: + raise RuntimeError("Request already started") self.curl = await self.session.pop_curl() self.enter = self.loop.create_future() - request, _, header_buffer = self.session._set_curl_options( - self.curl, - self.method, - self.url, - content_callback=self.on_content, - **self.options - ) - await self.session.acurl.add_handle(self.curl, False) - self.handle = self.session.acurl._curl2future[self.curl] - self.handle.add_done_callback(self.on_done) + if is_newer_0_5_10: + request, _, header_buffer, _, _ = self.session._set_curl_options( + self.curl, + self.method, + self.url, + content_callback=self._on_content, + **self.options + ) + else: + request, _, header_buffer = self.session._set_curl_options( + self.curl, + self.method, + self.url, + content_callback=self._on_content, + **self.options + ) + if is_newer_0_5_9: + self.handle = self.session.acurl.add_handle(self.curl) + else: + await self.session.acurl.add_handle(self.curl, False) + self.handle = self.session.acurl._curl2future[self.curl] + self.handle.add_done_callback(self._on_done) + # Wait for headers await self.enter + # Raise exceptions + if self.handle.done(): + self.handle.result() + if is_newer_0_5_8: + response = self.session._parse_response(self.curl, _, header_buffer) + response.request = request + else: + response = self.session._parse_response(self.curl, request, _, header_buffer) return StreamResponse( - self.session._parse_response(self.curl, request, _, header_buffer), - self.content, - request + response, + self.queue ) - - async def __aexit__(self, exc_type, exc, tb): - pass + + async def __aenter__(self) -> StreamResponse: + return await self.fetch() + + async def __aexit__(self, *args): + self.session.release_curl(self.curl) -class AsyncSession(BaseSession): +class StreamSession(AsyncSession): def request( self, method: str, |