From ea8d6b847a0e656cc5583948c5745592adda7103 Mon Sep 17 00:00:00 2001
From: Heiner Lohaus
Date: Sat, 13 Jan 2024 15:37:36 +0100
Subject: Support upload image in gui Add image upload to OpenaiChat Add image
response to OpenaiChat Improve ChatGPT Plus Support Remove unused
requirements
---
etc/testing/test_chat_completion.py | 11 +-
g4f/Provider/Bing.py | 18 +-
g4f/Provider/base_provider.py | 7 +-
g4f/Provider/bing/conversation.py | 7 +-
g4f/Provider/bing/create_images.py | 18 +-
g4f/Provider/bing/upload_image.py | 166 +++++-----------
g4f/Provider/create_images.py | 10 +-
g4f/Provider/needs_auth/OpenaiChat.py | 344 ++++++++++++++++++++++++----------
g4f/__init__.py | 2 +-
g4f/base_provider.py | 2 +-
g4f/gui/client/css/style.css | 34 +++-
g4f/gui/client/html/index.html | 93 ++-------
g4f/gui/client/js/chat.v1.js | 96 +++++-----
g4f/gui/server/backend.py | 58 ++++--
g4f/image.py | 116 ++++++++++++
g4f/requests.py | 31 +--
g4f/typing.py | 4 +-
g4f/version.py | 26 +--
requirements.txt | 4 -
setup.py | 4 -
20 files changed, 610 insertions(+), 441 deletions(-)
create mode 100644 g4f/image.py
diff --git a/etc/testing/test_chat_completion.py b/etc/testing/test_chat_completion.py
index 7058ab4c..615c8be0 100644
--- a/etc/testing/test_chat_completion.py
+++ b/etc/testing/test_chat_completion.py
@@ -7,10 +7,9 @@ import g4f, asyncio
print("create:", end=" ", flush=True)
for response in g4f.ChatCompletion.create(
- model=g4f.models.gpt_4_32k_0613,
- provider=g4f.Provider.Aivvm,
+ model=g4f.models.default,
+ provider=g4f.Provider.Bing,
messages=[{"role": "user", "content": "write a poem about a tree"}],
- temperature=0.1,
stream=True
):
print(response, end="", flush=True)
@@ -18,10 +17,10 @@ print()
async def run_async():
response = await g4f.ChatCompletion.create_async(
- model=g4f.models.gpt_35_turbo_16k_0613,
- provider=g4f.Provider.GptGod,
+ model=g4f.models.default,
+ provider=g4f.Provider.Bing,
messages=[{"role": "user", "content": "hello!"}],
)
print("create_async:", response)
-# asyncio.run(run_async())
+asyncio.run(run_async())
diff --git a/g4f/Provider/Bing.py b/g4f/Provider/Bing.py
index b0949397..da9b0172 100644
--- a/g4f/Provider/Bing.py
+++ b/g4f/Provider/Bing.py
@@ -8,11 +8,10 @@ import time
from urllib import parse
from aiohttp import ClientSession, ClientTimeout
-from ..typing import AsyncResult, Messages
+from ..typing import AsyncResult, Messages, ImageType
from .base_provider import AsyncGeneratorProvider
-from ..webdriver import get_browser, get_driver_cookies
from .bing.upload_image import upload_image
-from .bing.create_images import create_images, format_images_markdown, wait_for_login
+from .bing.create_images import create_images, format_images_markdown
from .bing.conversation import Conversation, create_conversation, delete_conversation
class Tones():
@@ -34,7 +33,7 @@ class Bing(AsyncGeneratorProvider):
timeout: int = 900,
cookies: dict = None,
tone: str = Tones.balanced,
- image: str = None,
+ image: ImageType = None,
web_search: bool = False,
**kwargs
) -> AsyncResult:
@@ -247,7 +246,7 @@ def create_message(
async def stream_generate(
prompt: str,
tone: str,
- image: str = None,
+ image: ImageType = None,
context: str = None,
proxy: str = None,
cookies: dict = None,
@@ -315,14 +314,7 @@ async def stream_generate(
result = response['item']['result']
if result.get('error'):
if result["value"] == "CaptchaChallenge":
- driver = get_browser(proxy=proxy)
- try:
- wait_for_login(driver)
- cookies = get_driver_cookies(driver)
- finally:
- driver.quit()
- async for chunk in stream_generate(prompt, tone, image, context, proxy, cookies, web_search, gpt4_turbo, timeout):
- yield chunk
+ raise Exception(f"{result['value']}: Use other cookies or/and ip address")
else:
raise Exception(f"{result['value']}: {result['message']}")
return
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 6da7f6c6..e7e88841 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
from inspect import signature, Parameter
from .helper import get_event_loop, get_cookies, format_prompt
-from ..typing import CreateResult, AsyncResult, Messages, Union
+from ..typing import CreateResult, AsyncResult, Messages
from ..base_provider import BaseProvider
if sys.version_info < (3, 10):
@@ -77,8 +77,7 @@ class AbstractProvider(BaseProvider):
continue
if args:
args += ", "
- args += "\n"
- args += " " + name
+ args += "\n " + name
if name != "model" and param.annotation is not Parameter.empty:
args += f": {get_type_name(param.annotation)}"
if param.default == "":
@@ -156,7 +155,7 @@ class AsyncGeneratorProvider(AsyncProvider):
messages,
stream=False,
**kwargs
- )
+ ) if not isinstance(chunk, Exception)
])
@staticmethod
diff --git a/g4f/Provider/bing/conversation.py b/g4f/Provider/bing/conversation.py
index ef45cd82..9e011c26 100644
--- a/g4f/Provider/bing/conversation.py
+++ b/g4f/Provider/bing/conversation.py
@@ -10,7 +10,10 @@ class Conversation():
async def create_conversation(session: ClientSession, proxy: str = None) -> Conversation:
url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1199.4'
async with session.get(url, proxy=proxy) as response:
- data = await response.json()
+ try:
+ data = await response.json()
+ except:
+ raise RuntimeError(f"Response: {await response.text()}")
conversationId = data.get('conversationId')
clientId = data.get('clientId')
@@ -26,7 +29,7 @@ async def list_conversations(session: ClientSession) -> list:
response = await response.json()
return response["chats"]
-async def delete_conversation(session: ClientSession, conversation: Conversation, proxy: str = None) -> list:
+async def delete_conversation(session: ClientSession, conversation: Conversation, proxy: str = None) -> bool:
url = "https://sydney.bing.com/sydney/DeleteSingleConversation"
json = {
"conversationId": conversation.conversationId,
diff --git a/g4f/Provider/bing/create_images.py b/g4f/Provider/bing/create_images.py
index b203a0dc..a1ecace3 100644
--- a/g4f/Provider/bing/create_images.py
+++ b/g4f/Provider/bing/create_images.py
@@ -9,6 +9,7 @@ from ..create_images import CreateImagesProvider
from ..helper import get_cookies, get_event_loop
from ...webdriver import WebDriver, get_driver_cookies, get_browser
from ...base_provider import ProviderType
+from ...image import format_images_markdown
BING_URL = "https://www.bing.com"
@@ -23,6 +24,7 @@ def wait_for_login(driver: WebDriver, timeout: int = 1200) -> None:
raise RuntimeError("Timeout error")
value = driver.get_cookie("_U")
if value:
+ time.sleep(1)
return
time.sleep(0.5)
@@ -62,7 +64,8 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
errors = [
"this prompt is being reviewed",
"this prompt has been blocked",
- "we're working hard to offer image creator in more languages"
+ "we're working hard to offer image creator in more languages",
+ "we can't create your images right now"
]
text = (await response.text()).lower()
for error in errors:
@@ -72,7 +75,7 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=3&FORM=GENCRE"
async with session.post(url, allow_redirects=False, proxy=proxy, timeout=timeout) as response:
if response.status != 302:
- raise RuntimeError(f"Create images failed. Status Code: {response.status}")
+ raise RuntimeError(f"Create images failed. Code: {response.status}")
redirect_url = response.headers["Location"].replace("&nfy=1", "")
redirect_url = f"{BING_URL}{redirect_url}"
@@ -84,10 +87,10 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
start_time = time.time()
while True:
if time.time() - start_time > timeout:
- raise RuntimeError(f"Timeout error after {timeout} seconds")
+ raise RuntimeError(f"Timeout error after {timeout} sec")
async with session.get(polling_url) as response:
if response.status != 200:
- raise RuntimeError(f"Polling images faild. Status Code: {response.status}")
+ raise RuntimeError(f"Polling images faild. Code: {response.status}")
text = await response.text()
if not text:
await asyncio.sleep(1)
@@ -119,13 +122,6 @@ def read_images(text: str) -> list:
raise RuntimeError("No images found")
return images
-def format_images_markdown(images: list, prompt: str) -> str:
- images = [f"[![#{idx+1} {prompt}]({image}?w=200&h=200)]({image})" for idx, image in enumerate(images)]
- images = "\n".join(images)
- start_flag = "\n"
- end_flag = "\n"
- return f"\n{start_flag}{images}\n{end_flag}\n"
-
async def create_images_markdown(cookies: dict, prompt: str, proxy: str = None) -> str:
session = create_session(cookies)
try:
diff --git a/g4f/Provider/bing/upload_image.py b/g4f/Provider/bing/upload_image.py
index 329e6df4..a7413207 100644
--- a/g4f/Provider/bing/upload_image.py
+++ b/g4f/Provider/bing/upload_image.py
@@ -3,70 +3,59 @@ from __future__ import annotations
import string
import random
import json
-import re
-import io
-import base64
import numpy as np
-from PIL import Image
+from ...typing import ImageType
from aiohttp import ClientSession
+from ...image import to_image, process_image, to_base64
+
+image_config = {
+ "maxImagePixels": 360000,
+ "imageCompressionRate": 0.7,
+ "enableFaceBlurDebug": 0,
+}
async def upload_image(
session: ClientSession,
- image: str,
+ image: ImageType,
tone: str,
proxy: str = None
-):
- try:
- image_config = {
- "maxImagePixels": 360000,
- "imageCompressionRate": 0.7,
- "enableFaceBlurDebug": 0,
- }
- is_data_uri_an_image(image)
- img_binary_data = extract_data_uri(image)
- is_accepted_format(img_binary_data)
- img = Image.open(io.BytesIO(img_binary_data))
- width, height = img.size
- max_image_pixels = image_config['maxImagePixels']
- if max_image_pixels / (width * height) < 1:
- new_width = int(width * np.sqrt(max_image_pixels / (width * height)))
- new_height = int(height * np.sqrt(max_image_pixels / (width * height)))
- else:
- new_width = width
- new_height = height
- try:
- orientation = get_orientation(img)
- except Exception:
- orientation = None
- new_img = process_image(orientation, img, new_width, new_height)
- new_img_binary_data = compress_image_to_base64(new_img, image_config['imageCompressionRate'])
- data, boundary = build_image_upload_api_payload(new_img_binary_data, tone)
- headers = session.headers.copy()
- headers["content-type"] = f'multipart/form-data; boundary={boundary}'
- headers["referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
- headers["origin"] = 'https://www.bing.com'
- async with session.post("https://www.bing.com/images/kblob", data=data, headers=headers, proxy=proxy) as response:
- if response.status != 200:
- raise RuntimeError("Failed to upload image.")
- image_info = await response.json()
- if not image_info.get('blobId'):
- raise RuntimeError("Failed to parse image info.")
- result = {'bcid': image_info.get('blobId', "")}
- result['blurredBcid'] = image_info.get('processedBlobId', "")
- if result['blurredBcid'] != "":
- result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['blurredBcid']
- elif result['bcid'] != "":
- result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['bcid']
- result['originalImageUrl'] = (
- "https://www.bing.com/images/blob?bcid="
- + result['blurredBcid']
- if image_config["enableFaceBlurDebug"]
- else "https://www.bing.com/images/blob?bcid="
- + result['bcid']
- )
- return result
- except Exception as e:
- raise RuntimeError(f"Upload image failed: {e}")
+) -> dict:
+ image = to_image(image)
+ width, height = image.size
+ max_image_pixels = image_config['maxImagePixels']
+ if max_image_pixels / (width * height) < 1:
+ new_width = int(width * np.sqrt(max_image_pixels / (width * height)))
+ new_height = int(height * np.sqrt(max_image_pixels / (width * height)))
+ else:
+ new_width = width
+ new_height = height
+ new_img = process_image(image, new_width, new_height)
+ new_img_binary_data = to_base64(new_img, image_config['imageCompressionRate'])
+ data, boundary = build_image_upload_api_payload(new_img_binary_data, tone)
+ headers = session.headers.copy()
+ headers["content-type"] = f'multipart/form-data; boundary={boundary}'
+ headers["referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
+ headers["origin"] = 'https://www.bing.com'
+ async with session.post("https://www.bing.com/images/kblob", data=data, headers=headers, proxy=proxy) as response:
+ if response.status != 200:
+ raise RuntimeError("Failed to upload image.")
+ image_info = await response.json()
+ if not image_info.get('blobId'):
+ raise RuntimeError("Failed to parse image info.")
+ result = {'bcid': image_info.get('blobId', "")}
+ result['blurredBcid'] = image_info.get('processedBlobId', "")
+ if result['blurredBcid'] != "":
+ result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['blurredBcid']
+ elif result['bcid'] != "":
+ result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['bcid']
+ result['originalImageUrl'] = (
+ "https://www.bing.com/images/blob?bcid="
+ + result['blurredBcid']
+ if image_config["enableFaceBlurDebug"]
+ else "https://www.bing.com/images/blob?bcid="
+ + result['bcid']
+ )
+ return result
def build_image_upload_api_payload(image_bin: str, tone: str):
@@ -98,65 +87,4 @@ def build_image_upload_api_payload(image_bin: str, tone: str):
+ boundary
+ "--\r\n"
)
- return data, boundary
-
-def is_data_uri_an_image(data_uri: str):
- # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
- if not re.match(r'data:image/(\w+);base64,', data_uri):
- raise ValueError("Invalid data URI image.")
- # Extract the image format from the data URI
- image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
- # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
- if image_format.lower() not in ['jpeg', 'jpg', 'png', 'gif']:
- raise ValueError("Invalid image format (from mime file type).")
-
-def is_accepted_format(binary_data: bytes) -> bool:
- if binary_data.startswith(b'\xFF\xD8\xFF'):
- pass # It's a JPEG image
- elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
- pass # It's a PNG image
- elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
- pass # It's a GIF image
- elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
- pass # It's a JPEG image
- elif binary_data.startswith(b'\xFF\xD8'):
- pass # It's a JPEG image
- elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
- pass # It's a WebP image
- else:
- raise ValueError("Invalid image format (from magic code).")
-
-def extract_data_uri(data_uri: str) -> bytes:
- data = data_uri.split(",")[1]
- data = base64.b64decode(data)
- return data
-
-def get_orientation(data: bytes) -> int:
- if data[:2] != b'\xFF\xD8':
- raise Exception('NotJpeg')
- with Image.open(data) as img:
- exif_data = img._getexif()
- if exif_data is not None:
- orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
- if orientation is not None:
- return orientation
-
-def process_image(orientation: int, img: Image.Image, new_width: int, new_height: int) -> Image.Image:
- # Initialize the canvas
- new_img = Image.new("RGB", (new_width, new_height), color="#FFFFFF")
- if orientation:
- if orientation > 4:
- img = img.transpose(Image.FLIP_LEFT_RIGHT)
- if orientation in [3, 4]:
- img = img.transpose(Image.ROTATE_180)
- if orientation in [5, 6]:
- img = img.transpose(Image.ROTATE_270)
- if orientation in [7, 8]:
- img = img.transpose(Image.ROTATE_90)
- new_img.paste(img, (0, 0))
- return new_img
-
-def compress_image_to_base64(image: Image.Image, compression_rate: float) -> str:
- output_buffer = io.BytesIO()
- image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
- return base64.b64encode(output_buffer.getvalue()).decode('utf-8')
\ No newline at end of file
+ return data, boundary
\ No newline at end of file
diff --git a/g4f/Provider/create_images.py b/g4f/Provider/create_images.py
index 29f88a80..f8a0442d 100644
--- a/g4f/Provider/create_images.py
+++ b/g4f/Provider/create_images.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import re
import asyncio
+from .. import debug
from ..typing import CreateResult, Messages
from ..base_provider import BaseProvider, ProviderType
@@ -26,12 +27,11 @@ class CreateImagesProvider(BaseProvider):
self.create_images = create_images
self.create_images_async = create_async
self.system_message = system_message
+ self.include_placeholder = include_placeholder
self.__name__ = provider.__name__
+ self.url = provider.url
self.working = provider.working
self.supports_stream = provider.supports_stream
- self.include_placeholder = include_placeholder
- if hasattr(provider, "url"):
- self.url = provider.url
def create_completion(
self,
@@ -54,6 +54,8 @@ class CreateImagesProvider(BaseProvider):
yield start
if self.include_placeholder:
yield placeholder
+ if debug.logging:
+ print(f"Create images with prompt: {prompt}")
yield from self.create_images(prompt)
if append:
yield append
@@ -76,6 +78,8 @@ class CreateImagesProvider(BaseProvider):
placeholders = []
for placeholder, prompt in matches:
if placeholder not in placeholders:
+ if debug.logging:
+ print(f"Create images with prompt: {prompt}")
results.append(self.create_images_async(prompt))
placeholders.append(placeholder)
results = await asyncio.gather(*results)
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 4651955c..4b11aeaf 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -2,17 +2,18 @@ from __future__ import annotations
import uuid, json, asyncio, os
from py_arkose_generator.arkose import get_values_for_request
-from asyncstdlib.itertools import tee
from async_property import async_cached_property
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from ..base_provider import AsyncGeneratorProvider
-from ..helper import get_event_loop, format_prompt, get_cookies
-from ...webdriver import get_browser
+from ..helper import format_prompt, get_cookies
+from ...webdriver import get_browser, get_driver_cookies
from ...typing import AsyncResult, Messages
from ...requests import StreamSession
+from ...image import to_image, to_bytes, ImageType, ImageResponse
+from ... import debug
models = {
"gpt-3.5": "text-davinci-002-render-sha",
@@ -28,6 +29,7 @@ class OpenaiChat(AsyncGeneratorProvider):
supports_gpt_35_turbo = True
supports_gpt_4 = True
_cookies: dict = {}
+ _default_model: str = None
@classmethod
async def create(
@@ -39,6 +41,7 @@ class OpenaiChat(AsyncGeneratorProvider):
action: str = "next",
conversation_id: str = None,
parent_id: str = None,
+ image: ImageType = None,
**kwargs
) -> Response:
if prompt:
@@ -53,16 +56,120 @@ class OpenaiChat(AsyncGeneratorProvider):
action=action,
conversation_id=conversation_id,
parent_id=parent_id,
+ image=image,
response_fields=True,
**kwargs
)
return Response(
generator,
- await anext(generator),
action,
messages,
kwargs
)
+
+ @classmethod
+ async def upload_image(
+ cls,
+ session: StreamSession,
+ headers: dict,
+ image: ImageType
+ ) -> ImageResponse:
+ image = to_image(image)
+ extension = image.format.lower()
+ data_bytes = to_bytes(image)
+ data = {
+ "file_name": f"{image.width}x{image.height}.{extension}",
+ "file_size": len(data_bytes),
+ "use_case": "multimodal"
+ }
+ async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
+ response.raise_for_status()
+ image_data = {
+ **data,
+ **await response.json(),
+ "mime_type": f"image/{extension}",
+ "extension": extension,
+ "height": image.height,
+ "width": image.width
+ }
+ async with session.put(
+ image_data["upload_url"],
+ data=data_bytes,
+ headers={
+ "Content-Type": image_data["mime_type"],
+ "x-ms-blob-type": "BlockBlob"
+ }
+ ) as response:
+ response.raise_for_status()
+ async with session.post(
+ f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
+ json={},
+ headers=headers
+ ) as response:
+ response.raise_for_status()
+ download_url = (await response.json())["download_url"]
+ return ImageResponse(download_url, image_data["file_name"], image_data)
+
+ @classmethod
+ async def get_default_model(cls, session: StreamSession, headers: dict):
+ if cls._default_model:
+ model = cls._default_model
+ else:
+ async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
+ data = await response.json()
+ if "categories" in data:
+ model = data["categories"][-1]["default_model"]
+ else:
+ RuntimeError(f"Response: {data}")
+ cls._default_model = model
+ return model
+
+ @classmethod
+ def create_messages(cls, prompt: str, image_response: ImageResponse = None):
+ if not image_response:
+ content = {"content_type": "text", "parts": [prompt]}
+ else:
+ content = {
+ "content_type": "multimodal_text",
+ "parts": [{
+ "asset_pointer": f"file-service://{image_response.get('file_id')}",
+ "height": image_response.get("height"),
+ "size_bytes": image_response.get("file_size"),
+ "width": image_response.get("width"),
+ }, prompt]
+ }
+ messages = [{
+ "id": str(uuid.uuid4()),
+ "author": {"role": "user"},
+ "content": content,
+ }]
+ if image_response:
+ messages[0]["metadata"] = {
+ "attachments": [{
+ "height": image_response.get("height"),
+ "id": image_response.get("file_id"),
+ "mimeType": image_response.get("mime_type"),
+ "name": image_response.get("file_name"),
+ "size": image_response.get("file_size"),
+ "width": image_response.get("width"),
+ }]
+ }
+ return messages
+
+ @classmethod
+ async def get_image_response(cls, session: StreamSession, headers: dict, line: dict):
+ if "parts" in line["message"]["content"]:
+ part = line["message"]["content"]["parts"][0]
+ if "asset_pointer" in part and part["metadata"]:
+ file_id = part["asset_pointer"].split("file-service://", 1)[1]
+ prompt = part["metadata"]["dalle"]["prompt"]
+ async with session.get(
+ f"{cls.url}/backend-api/files/{file_id}/download",
+ headers=headers
+ ) as response:
+ response.raise_for_status()
+ download_url = (await response.json())["download_url"]
+ return ImageResponse(download_url, prompt)
@classmethod
async def create_async_generator(
@@ -78,13 +185,12 @@ class OpenaiChat(AsyncGeneratorProvider):
action: str = "next",
conversation_id: str = None,
parent_id: str = None,
+ image: ImageType = None,
response_fields: bool = False,
**kwargs
) -> AsyncResult:
- if not model:
- model = "gpt-3.5"
- elif model not in models:
- raise ValueError(f"Model are not supported: {model}")
+ if model in models:
+ model = models[model]
if not parent_id:
parent_id = str(uuid.uuid4())
if not cookies:
@@ -98,115 +204,131 @@ class OpenaiChat(AsyncGeneratorProvider):
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [ChatGPT]({login_url})\n\n"
- cls._cookies["access_token"] = access_token = await cls.browse_access_token(proxy)
+ access_token, cookies = cls.browse_access_token(proxy)
+ cls._cookies = cookies
headers = {
- "Accept": "text/event-stream",
"Authorization": f"Bearer {access_token}",
}
async with StreamSession(
proxies={"https": proxy},
impersonate="chrome110",
- headers=headers,
timeout=timeout,
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
) as session:
+ if not model:
+ model = await cls.get_default_model(session, headers)
+ try:
+ image_response = None
+ if image:
+ image_response = await cls.upload_image(session, headers, image)
+ yield image_response
+ except Exception as e:
+ yield e
end_turn = EndTurn()
while not end_turn.is_end:
data = {
"action": action,
- "arkose_token": await get_arkose_token(proxy, timeout),
+ "arkose_token": await cls.get_arkose_token(session),
"conversation_id": conversation_id,
"parent_message_id": parent_id,
- "model": models[model],
+ "model": model,
"history_and_training_disabled": history_disabled and not auto_continue,
}
if action != "continue":
prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
- data["messages"] = [{
- "id": str(uuid.uuid4()),
- "author": {"role": "user"},
- "content": {"content_type": "text", "parts": [prompt]},
- }]
- async with session.post(f"{cls.url}/backend-api/conversation", json=data) as response:
+ data["messages"] = cls.create_messages(prompt, image_response)
+ async with session.post(
+ f"{cls.url}/backend-api/conversation",
+ json=data,
+ headers={"Accept": "text/event-stream", **headers}
+ ) as response:
try:
response.raise_for_status()
except:
- raise RuntimeError(f"Error {response.status_code}: {await response.text()}")
- last_message = 0
- async for line in response.iter_lines():
- if not line.startswith(b"data: "):
- continue
- line = line[6:]
- if line == b"[DONE]":
- break
- try:
- line = json.loads(line)
- except:
- continue
- if "message" not in line:
- continue
- if "error" in line and line["error"]:
- raise RuntimeError(line["error"])
- if "message_type" not in line["message"]["metadata"]:
- continue
- if line["message"]["author"]["role"] != "assistant":
- continue
- if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"):
- conversation_id = line["conversation_id"]
- parent_id = line["message"]["id"]
- if response_fields:
- response_fields = False
- yield ResponseFields(conversation_id, parent_id, end_turn)
- new_message = line["message"]["content"]["parts"][0]
- yield new_message[last_message:]
- last_message = len(new_message)
- if "finish_details" in line["message"]["metadata"]:
- if line["message"]["metadata"]["finish_details"]["type"] == "stop":
- end_turn.end()
+ raise RuntimeError(f"Response {response.status_code}: {await response.text()}")
+ try:
+ last_message: int = 0
+ async for line in response.iter_lines():
+ if not line.startswith(b"data: "):
+ continue
+ elif line.startswith(b"data: [DONE]"):
+ break
+ try:
+ line = json.loads(line[6:])
+ except:
+ continue
+ if "message" not in line:
+ continue
+ if "error" in line and line["error"]:
+ raise RuntimeError(line["error"])
+ if "message_type" not in line["message"]["metadata"]:
+ continue
+ try:
+ image_response = await cls.get_image_response(session, headers, line)
+ if image_response:
+ yield image_response
+ except Exception as e:
+ yield e
+ if line["message"]["author"]["role"] != "assistant":
+ continue
+ if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"):
+ conversation_id = line["conversation_id"]
+ parent_id = line["message"]["id"]
+ if response_fields:
+ response_fields = False
+ yield ResponseFields(conversation_id, parent_id, end_turn)
+ if "parts" in line["message"]["content"]:
+ new_message = line["message"]["content"]["parts"][0]
+ if len(new_message) > last_message:
+ yield new_message[last_message:]
+ last_message = len(new_message)
+ if "finish_details" in line["message"]["metadata"]:
+ if line["message"]["metadata"]["finish_details"]["type"] == "stop":
+ end_turn.end()
+ break
+ except Exception as e:
+ yield e
if not auto_continue:
break
action = "continue"
await asyncio.sleep(5)
+ if history_disabled:
+ async with session.patch(
+ f"{cls.url}/backend-api/conversation/{conversation_id}",
+ json={"is_visible": False},
+ headers=headers
+ ) as response:
+ response.raise_for_status()
@classmethod
- async def browse_access_token(cls, proxy: str = None) -> str:
- def browse() -> str:
- driver = get_browser(proxy=proxy)
- try:
- driver.get(f"{cls.url}/")
- WebDriverWait(driver, 1200).until(
- EC.presence_of_element_located((By.ID, "prompt-textarea"))
- )
- javascript = """
+ def browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
+ driver = get_browser(proxy=proxy)
+ try:
+ driver.get(f"{cls.url}/")
+ WebDriverWait(driver, 1200).until(
+ EC.presence_of_element_located((By.ID, "prompt-textarea"))
+ )
+ javascript = """
access_token = (await (await fetch('/api/auth/session')).json())['accessToken'];
expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week
document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/';
return access_token;
"""
- return driver.execute_script(javascript)
- finally:
- driver.quit()
- loop = get_event_loop()
- return await loop.run_in_executor(
- None,
- browse
- )
-
-async def get_arkose_token(proxy: str = None, timeout: int = None) -> str:
- config = {
- "pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC",
- "surl": "https://tcr9i.chat.openai.com",
- "headers": {
- "User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36'
- },
- "site": "https://chat.openai.com",
- }
- args_for_request = get_values_for_request(config)
- async with StreamSession(
- proxies={"https": proxy},
- impersonate="chrome107",
- timeout=timeout
- ) as session:
+ return driver.execute_script(javascript), get_driver_cookies(driver)
+ finally:
+ driver.quit()
+
+ @classmethod
+ async def get_arkose_token(cls, session: StreamSession) -> str:
+ config = {
+ "pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC",
+ "surl": "https://tcr9i.chat.openai.com",
+ "headers": {
+ "User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36'
+ },
+ "site": cls.url,
+ }
+ args_for_request = get_values_for_request(config)
async with session.post(**args_for_request) as response:
response.raise_for_status()
decoded_json = await response.json()
@@ -236,23 +358,47 @@ class Response():
def __init__(
self,
generator: AsyncResult,
- fields: ResponseFields,
action: str,
messages: Messages,
options: dict
):
- self.aiter, self.copy = tee(generator)
- self.fields = fields
- self.action = action
+ self._generator = generator
+ self.action: str = action
+ self.is_end: bool = False
+ self._message = None
self._messages = messages
self._options = options
+ self._fields = None
+
+ async def generator(self):
+ if self._generator:
+ self._generator = None
+ chunks = []
+ async for chunk in self._generator:
+ if isinstance(chunk, ResponseFields):
+ self._fields = chunk
+ else:
+ yield chunk
+ chunks.append(str(chunk))
+ self._message = "".join(chunks)
+ if not self._fields:
+ raise RuntimeError("Missing response fields")
+ self.is_end = self._fields._end_turn.is_end
def __aiter__(self):
- return self.aiter
+ return self.generator()
@async_cached_property
async def message(self) -> str:
- return "".join([chunk async for chunk in self.copy])
+ [_ async for _ in self.generator()]
+ return self._message
+
+ async def get_fields(self):
+ [_ async for _ in self.generator()]
+ return {
+ "conversation_id": self._fields.conversation_id,
+ "parent_id": self._fields.message_id,
+ }
async def next(self, prompt: str, **kwargs) -> Response:
return await OpenaiChat.create(
@@ -260,20 +406,19 @@ class Response():
prompt=prompt,
messages=await self.messages,
action="next",
- conversation_id=self.fields.conversation_id,
- parent_id=self.fields.message_id,
+ **await self.get_fields(),
**kwargs
)
async def do_continue(self, **kwargs) -> Response:
- if self.end_turn:
+ fields = await self.get_fields()
+ if self.is_end:
raise RuntimeError("Can't continue message. Message already finished.")
return await OpenaiChat.create(
**self._options,
messages=await self.messages,
action="continue",
- conversation_id=self.fields.conversation_id,
- parent_id=self.fields.message_id,
+ **fields,
**kwargs
)
@@ -284,8 +429,7 @@ class Response():
**self._options,
messages=self._messages,
action="variant",
- conversation_id=self.fields.conversation_id,
- parent_id=self.fields.message_id,
+ **await self.get_fields(),
**kwargs
)
@@ -295,8 +439,4 @@ class Response():
messages.append({
"role": "assistant", "content": await self.message
})
- return messages
-
- @property
- def end_turn(self):
- return self.fields._end_turn.is_end
\ No newline at end of file
+ return messages
\ No newline at end of file
diff --git a/g4f/__init__.py b/g4f/__init__.py
index dc7808f9..68f9ccf6 100644
--- a/g4f/__init__.py
+++ b/g4f/__init__.py
@@ -17,7 +17,7 @@ def get_model_and_provider(model : Union[Model, str],
ignore_stream: bool = False) -> tuple[str, ProviderType]:
if debug.version_check:
debug.version_check = False
- version.utils.check_pypi_version()
+ version.utils.check_version()
if isinstance(provider, str):
if provider in ProviderUtils.convert:
diff --git a/g4f/base_provider.py b/g4f/base_provider.py
index 84cbc384..1863f6bc 100644
--- a/g4f/base_provider.py
+++ b/g4f/base_provider.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from .typing import Messages, CreateResult, Union
class BaseProvider(ABC):
- url: str
+ url: str = None
working: bool = False
needs_auth: bool = False
supports_stream: bool = False
diff --git a/g4f/gui/client/css/style.css b/g4f/gui/client/css/style.css
index 3e2d6d6f..59464272 100644
--- a/g4f/gui/client/css/style.css
+++ b/g4f/gui/client/css/style.css
@@ -217,7 +217,6 @@ body {
}
.message {
-
width: 100%;
overflow-wrap: break-word;
display: flex;
@@ -302,10 +301,14 @@ body {
line-height: 1.3;
color: var(--colour-3);
}
-.message .content pre {
+.message .content pre{
white-space: pre-wrap;
}
+.message .content img{
+ max-width: 400px;
+}
+
.message .user i {
position: absolute;
bottom: -6px;
@@ -401,13 +404,28 @@ body {
display: none;
}
-input[type="checkbox"] {
+#image {
+ display: none;
+}
+
+label[for="image"]:has(> input:valid){
+ color: var(--accent);
+}
+
+label[for="image"] {
+ cursor: pointer;
+ position: absolute;
+ top: 10px;
+ left: 10px;
+}
+
+.buttons input[type="checkbox"] {
height: 0;
width: 0;
display: none;
}
-label {
+.buttons label {
cursor: pointer;
text-indent: -9999px;
width: 50px;
@@ -424,7 +442,7 @@ label {
transition: 0.33s;
}
-label:after {
+.buttons label:after {
content: "";
position: absolute;
top: 50%;
@@ -437,11 +455,11 @@ label:after {
transition: 0.33s;
}
-input:checked+label {
- background: var(--blur-border);
+.buttons input:checked+label {
+ background: var(--accent);
}
-input:checked+label:after {
+.buttons input:checked+label:after {
left: calc(100% - 5px - 20px);
}
diff --git a/g4f/gui/client/html/index.html b/g4f/gui/client/html/index.html
index bc41bd45..3f2bb0c0 100644
--- a/g4f/gui/client/html/index.html
+++ b/g4f/gui/client/html/index.html
@@ -36,7 +36,8 @@
#message-input {
margin-right: 30px;
- height: 80px;
+ height: 82px;
+ margin-left: 20px;
}
#message-input::-webkit-scrollbar {
@@ -113,6 +114,10 @@
diff --git a/g4f/gui/client/js/chat.v1.js b/g4f/gui/client/js/chat.v1.js
index fffe9fe9..e763f52d 100644
--- a/g4f/gui/client/js/chat.v1.js
+++ b/g4f/gui/client/js/chat.v1.js
@@ -7,6 +7,7 @@ const spinner = box_conversations.querySelector(".spinner");
const stop_generating = document.querySelector(`.stop_generating`);
const regenerate = document.querySelector(`.regenerate`);
const send_button = document.querySelector(`#send-button`);
+const imageInput = document.querySelector('#image') ;
let prompt_lock = false;
hljs.addPlugin(new CopyButtonPlugin());
@@ -34,7 +35,7 @@ const delete_conversations = async () => {
};
const handle_ask = async () => {
- message_input.style.height = `80px`;
+ message_input.style.height = `82px`;
message_input.focus();
window.scrollTo(0, 0);
message = message_input.value
@@ -103,8 +104,7 @@ const ask_gpt = async () => {
`;
@@ -114,29 +114,32 @@ const ask_gpt = async () => {
message_box.scrollTop = message_box.scrollHeight;
window.scrollTo(0, 0);
try {
+ let body = JSON.stringify({
+ id: window.token,
+ conversation_id: window.conversation_id,
+ model: model.options[model.selectedIndex].value,
+ jailbreak: jailbreak.options[jailbreak.selectedIndex].value,
+ web_search: document.getElementById(`switch`).checked,
+ provider: provider.options[provider.selectedIndex].value,
+ patch_provider: document.getElementById('patch').checked,
+ messages: messages
+ });
+ const headers = {
+ accept: 'text/event-stream'
+ }
+ if (imageInput && imageInput.files.length > 0) {
+ const formData = new FormData();
+ formData.append('image', imageInput.files[0]);
+ formData.append('json', body);
+ body = formData;
+ } else {
+ headers['content-type'] = 'application/json';
+ }
const response = await fetch(`/backend-api/v2/conversation`, {
- method: `POST`,
+ method: 'POST',
signal: window.controller.signal,
- headers: {
- 'content-type': `application/json`,
- accept: `text/event-stream`,
- },
- body: JSON.stringify({
- conversation_id: window.conversation_id,
- action: `_ask`,
- model: model.options[model.selectedIndex].value,
- jailbreak: jailbreak.options[jailbreak.selectedIndex].value,
- internet_access: document.getElementById(`switch`).checked,
- provider: provider.options[provider.selectedIndex].value,
- patch_provider: document.getElementById('patch').checked,
- meta: {
- id: window.token,
- content: {
- content_type: `text`,
- parts: messages,
- },
- },
- }),
+ headers: headers,
+ body: body
});
await new Promise((r) => setTimeout(r, 1000));
@@ -159,13 +162,17 @@ const ask_gpt = async () => {
'' + provider.name + ""
} else if (message["type"] == "error") {
error = message["error"];
+ } else if (message["type"] == "message") {
+ console.error(message["message"])
}
}
if (error) {
console.error(error);
content_inner.innerHTML = "An error occured, please try again, if the problem persists, please use a other model or provider";
} else {
- content_inner.innerHTML = markdown_render(text);
+ html = markdown_render(text);
+ html = html.substring(0, html.lastIndexOf('
')) + '';
+ content_inner.innerHTML = html;
document.querySelectorAll('code').forEach((el) => {
hljs.highlightElement(el);
});
@@ -174,9 +181,9 @@ const ask_gpt = async () => {
window.scrollTo(0, 0);
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
}
+ if (!error && imageInput) imageInput.value = "";
} catch (e) {
- console.log(e);
-
+ console.error(e);
if (e.name != `AbortError`) {
text = `oops ! something went wrong, please try again / reload. [stacktrace in console]`;
@@ -444,34 +451,34 @@ document.querySelector(".mobile-sidebar").addEventListener("click", (event) => {
});
const register_settings_localstorage = async () => {
- settings_ids = ["switch", "model", "jailbreak", "patch", "provider"];
- settings_elements = settings_ids.map((id) => document.getElementById(id));
- settings_elements.map((element) =>
- element.addEventListener(`change`, async (event) => {
+ for (id of ["switch", "model", "jailbreak", "patch", "provider"]) {
+ element = document.getElementById(id);
+ element.addEventListener('change', async (event) => {
switch (event.target.type) {
case "checkbox":
- localStorage.setItem(event.target.id, event.target.checked);
+ localStorage.setItem(id, event.target.checked);
break;
case "select-one":
- localStorage.setItem(event.target.id, event.target.selectedIndex);
+ localStorage.setItem(id, event.target.selectedIndex);
break;
default:
console.warn("Unresolved element type");
}
- })
- );
-};
+ });
+ }
+}
const load_settings_localstorage = async () => {
for (id of ["switch", "model", "jailbreak", "patch", "provider"]) {
element = document.getElementById(id);
- if (localStorage.getItem(element.id)) {
+ value = localStorage.getItem(element.id);
+ if (value) {
switch (element.type) {
case "checkbox":
- element.checked = localStorage.getItem(element.id) === "true";
+ element.checked = value === "true";
break;
case "select-one":
- element.selectedIndex = parseInt(localStorage.getItem(element.id));
+ element.selectedIndex = parseInt(value);
break;
default:
console.warn("Unresolved element type");
@@ -529,7 +536,6 @@ colorThemes.forEach((themeOption) => {
window.onload = async () => {
- load_settings_localstorage();
setTheme();
let conversations = 0;
@@ -610,16 +616,14 @@ observer.observe(message_input, { attributes: true });
option.value = option.text = model;
select.appendChild(option);
}
-})();
-(async () => {
response = await fetch('/backend-api/v2/providers')
providers = await response.json()
- let select = document.getElementById('provider');
+ select = document.getElementById('provider');
select.textContent = '';
- let auto = document.createElement('option');
+ auto = document.createElement('option');
auto.value = '';
auto.text = 'Provider: Auto';
select.appendChild(auto);
@@ -629,6 +633,8 @@ observer.observe(message_input, { attributes: true });
option.value = option.text = provider;
select.appendChild(option);
}
+
+ await load_settings_localstorage()
})();
(async () => {
@@ -644,4 +650,4 @@ observer.observe(message_input, { attributes: true });
text += versions["version"];
}
document.getElementById("version_text").innerHTML = text
-})();
\ No newline at end of file
+})()
\ No newline at end of file
diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py
index 67f13de4..3ccd1a59 100644
--- a/g4f/gui/server/backend.py
+++ b/g4f/gui/server/backend.py
@@ -3,6 +3,7 @@ import json
from flask import request, Flask
from g4f import debug, version, models
from g4f import _all_models, get_last_provider, ChatCompletion
+from g4f.image import is_allowed_extension, to_image
from g4f.Provider import __providers__
from g4f.Provider.bing.create_images import patch_provider
from .internet import get_search_message
@@ -55,7 +56,7 @@ class Backend_Api:
def version(self):
return {
"version": version.utils.current_version,
- "lastet_version": version.utils.latest_version,
+ "lastet_version": version.get_latest_version(),
}
def _gen_title(self):
@@ -64,15 +65,31 @@ class Backend_Api:
}
def _conversation(self):
- #jailbreak = request.json['jailbreak']
- messages = request.json['meta']['content']['parts']
- if request.json.get('internet_access'):
- messages[-1]["content"] = get_search_message(messages[-1]["content"])
- model = request.json.get('model')
+ kwargs = {}
+ if 'image' in request.files:
+ file = request.files['image']
+ if file.filename != '' and is_allowed_extension(file.filename):
+ kwargs['image'] = to_image(file.stream)
+ if 'json' in request.form:
+ json_data = json.loads(request.form['json'])
+ else:
+ json_data = request.json
+
+ provider = json_data.get('provider', '').replace('g4f.Provider.', '')
+ provider = provider if provider and provider != "Auto" else None
+ if provider == 'OpenaiChat':
+ kwargs['auto_continue'] = True
+ messages = json_data['messages']
+ if json_data.get('web_search'):
+ if provider == "Bing":
+ kwargs['web_search'] = True
+ else:
+ messages[-1]["content"] = get_search_message(messages[-1]["content"])
+ model = json_data.get('model')
model = model if model else models.default
- provider = request.json.get('provider', '').replace('g4f.Provider.', '')
+ provider = json_data.get('provider', '').replace('g4f.Provider.', '')
provider = provider if provider and provider != "Auto" else None
- patch = patch_provider if request.json.get('patch_provider') else None
+ patch = patch_provider if json_data.get('patch_provider') else None
def try_response():
try:
@@ -83,7 +100,8 @@ class Backend_Api:
messages=messages,
stream=True,
ignore_stream_and_auth=True,
- patch_provider=patch
+ patch_provider=patch,
+ **kwargs
):
if first:
first = False
@@ -91,16 +109,24 @@ class Backend_Api:
'type' : 'provider',
'provider': get_last_provider(True)
}) + "\n"
- yield json.dumps({
- 'type' : 'content',
- 'content': chunk,
- }) + "\n"
-
+ if isinstance(chunk, Exception):
+ yield json.dumps({
+ 'type' : 'message',
+ 'message': get_error_message(chunk),
+ }) + "\n"
+ else:
+ yield json.dumps({
+ 'type' : 'content',
+ 'content': str(chunk),
+ }) + "\n"
except Exception as e:
logging.exception(e)
yield json.dumps({
'type' : 'error',
- 'error': f'{e.__class__.__name__}: {e}'
+ 'error': get_error_message(e)
})
- return self.app.response_class(try_response(), mimetype='text/event-stream')
\ No newline at end of file
+ return self.app.response_class(try_response(), mimetype='text/event-stream')
+
+def get_error_message(exception: Exception) -> str:
+ return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}"
\ No newline at end of file
diff --git a/g4f/image.py b/g4f/image.py
new file mode 100644
index 00000000..4a97247e
--- /dev/null
+++ b/g4f/image.py
@@ -0,0 +1,116 @@
+import re
+from io import BytesIO
+import base64
+from .typing import ImageType, Union
+from PIL import Image
+
+ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
+
+def to_image(image: ImageType) -> Image.Image:
+ if isinstance(image, str):
+ is_data_uri_an_image(image)
+ image = extract_data_uri(image)
+ if isinstance(image, bytes):
+ is_accepted_format(image)
+ image = Image.open(BytesIO(image))
+ elif not isinstance(image, Image.Image):
+ image = Image.open(image)
+ copy = image.copy()
+ copy.format = image.format
+ image = copy
+ return image
+
+def is_allowed_extension(filename) -> bool:
+ return '.' in filename and \
+ filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+def is_data_uri_an_image(data_uri: str) -> bool:
+ # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
+ if not re.match(r'data:image/(\w+);base64,', data_uri):
+ raise ValueError("Invalid data URI image.")
+ # Extract the image format from the data URI
+ image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
+ # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
+ if image_format.lower() not in ALLOWED_EXTENSIONS:
+ raise ValueError("Invalid image format (from mime file type).")
+
+def is_accepted_format(binary_data: bytes) -> bool:
+ if binary_data.startswith(b'\xFF\xD8\xFF'):
+ pass # It's a JPEG image
+ elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
+ pass # It's a PNG image
+ elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
+ pass # It's a GIF image
+ elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
+ pass # It's a JPEG image
+ elif binary_data.startswith(b'\xFF\xD8'):
+ pass # It's a JPEG image
+ elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
+ pass # It's a WebP image
+ else:
+ raise ValueError("Invalid image format (from magic code).")
+
+def extract_data_uri(data_uri: str) -> bytes:
+ data = data_uri.split(",")[1]
+ data = base64.b64decode(data)
+ return data
+
+def get_orientation(image: Image.Image) -> int:
+ exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
+ if exif_data is not None:
+ orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
+ if orientation is not None:
+ return orientation
+
+def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
+ orientation = get_orientation(img)
+ new_img = Image.new("RGB", (new_width, new_height), color="#FFFFFF")
+ if orientation:
+ if orientation > 4:
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ if orientation in [3, 4]:
+ img = img.transpose(Image.ROTATE_180)
+ if orientation in [5, 6]:
+ img = img.transpose(Image.ROTATE_270)
+ if orientation in [7, 8]:
+ img = img.transpose(Image.ROTATE_90)
+ new_img.paste(img, (0, 0))
+ return new_img
+
+def to_base64(image: Image.Image, compression_rate: float) -> str:
+ output_buffer = BytesIO()
+ image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
+ return base64.b64encode(output_buffer.getvalue()).decode()
+
+def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
+ if isinstance(images, list):
+ images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
+ images = "\n".join(images)
+ else:
+ images = f"[![{prompt}]({images})]({images})"
+ start_flag = "\n"
+ end_flag = "\n"
+ return f"\n{start_flag}{images}\n{end_flag}\n"
+
+def to_bytes(image: Image.Image) -> bytes:
+ bytes_io = BytesIO()
+ image.save(bytes_io, image.format)
+ image.seek(0)
+ return bytes_io.getvalue()
+
+class ImageResponse():
+ def __init__(
+ self,
+ images: Union[str, list],
+ alt: str,
+ options: dict = {}
+ ):
+ self.images = images
+ self.alt = alt
+ self.options = options
+
+ def __str__(self) -> str:
+ return format_images_markdown(self.images, self.alt)
+
+ def get(self, key: str):
+ return self.options.get(key)
\ No newline at end of file
diff --git a/g4f/requests.py b/g4f/requests.py
index 467ea371..1a13dec9 100644
--- a/g4f/requests.py
+++ b/g4f/requests.py
@@ -11,12 +11,6 @@ from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_drive
class StreamResponse:
def __init__(self, inner: Response) -> None:
self.inner: Response = inner
- self.request = inner.request
- self.status_code: int = inner.status_code
- self.reason: str = inner.reason
- self.ok: bool = inner.ok
- self.headers = inner.headers
- self.cookies = inner.cookies
async def text(self) -> str:
return await self.inner.atext()
@@ -34,17 +28,26 @@ class StreamResponse:
async def iter_content(self) -> AsyncGenerator[bytes, None]:
async for chunk in self.inner.aiter_content():
yield chunk
+
+ async def __aenter__(self):
+ inner: Response = await self.inner
+ self.inner = inner
+ self.request = inner.request
+ self.status_code: int = inner.status_code
+ self.reason: str = inner.reason
+ self.ok: bool = inner.ok
+ self.headers = inner.headers
+ self.cookies = inner.cookies
+ return self
+
+ async def __aexit__(self, *args):
+ await self.inner.aclose()
class StreamSession(AsyncSession):
- @asynccontextmanager
- async def request(
+ def request(
self, method: str, url: str, **kwargs
- ) -> AsyncGenerator[StreamResponse]:
- response = await super().request(method, url, stream=True, **kwargs)
- try:
- yield StreamResponse(response)
- finally:
- await response.aclose()
+ ) -> StreamResponse:
+ return StreamResponse(super().request(method, url, stream=True, **kwargs))
head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")
diff --git a/g4f/typing.py b/g4f/typing.py
index c93a4bcf..c972f505 100644
--- a/g4f/typing.py
+++ b/g4f/typing.py
@@ -1,5 +1,6 @@
import sys
-from typing import Any, AsyncGenerator, Generator, NewType, Tuple, Union, List, Dict, Type
+from typing import Any, AsyncGenerator, Generator, NewType, Tuple, Union, List, Dict, Type, IO
+from PIL.Image import Image
if sys.version_info >= (3, 8):
from typing import TypedDict
@@ -10,6 +11,7 @@ SHA256 = NewType('sha_256_hash', str)
CreateResult = Generator[str, None, None]
AsyncResult = AsyncGenerator[str, None]
Messages = List[Dict[str, str]]
+ImageType = Union[str, bytes, IO, Image, None]
__all__ = [
'Any',
diff --git a/g4f/version.py b/g4f/version.py
index 44d14369..bb4b7f17 100644
--- a/g4f/version.py
+++ b/g4f/version.py
@@ -5,6 +5,15 @@ from importlib.metadata import version as get_package_version, PackageNotFoundEr
from subprocess import check_output, CalledProcessError, PIPE
from .errors import VersionNotFoundError
+def get_latest_version() -> str:
+ try:
+ get_package_version("g4f")
+ response = requests.get("https://pypi.org/pypi/g4f/json").json()
+ return response["info"]["version"]
+ except PackageNotFoundError:
+ url = "https://api.github.com/repos/xtekky/gpt4free/releases/latest"
+ response = requests.get(url).json()
+ return response["tag_name"]
class VersionUtils():
@cached_property
@@ -28,20 +37,13 @@ class VersionUtils():
@cached_property
def latest_version(self) -> str:
- try:
- get_package_version("g4f")
- response = requests.get("https://pypi.org/pypi/g4f/json").json()
- return response["info"]["version"]
- except PackageNotFoundError:
- url = "https://api.github.com/repos/xtekky/gpt4free/releases/latest"
- response = requests.get(url).json()
- return response["tag_name"]
-
- def check_pypi_version(self) -> None:
+ return get_latest_version()
+
+ def check_version(self) -> None:
try:
if self.current_version != self.latest_version:
- print(f'New pypi version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f')
+ print(f'New g4f version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f')
except Exception as e:
- print(f'Failed to check g4f pypi version: {e}')
+ print(f'Failed to check g4f version: {e}')
utils = VersionUtils()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index b9212b4e..f1e5a3dd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,12 +5,10 @@ aiohttp
certifi
browser_cookie3
websockets
-js2py
typing-extensions
PyExecJS
duckduckgo-search
nest_asyncio
-waitress
werkzeug
loguru
pillow
@@ -24,7 +22,5 @@ py-arkose-generator
asyncstdlib
async-property
undetected-chromedriver
-asyncstdlib
-async_property
brotli
beautifulsoup4
\ No newline at end of file
diff --git a/setup.py b/setup.py
index dc59597b..daafe26d 100644
--- a/setup.py
+++ b/setup.py
@@ -16,12 +16,10 @@ install_requires = [
"certifi",
"browser_cookie3",
"websockets",
- "js2py",
"typing-extensions",
"PyExecJS",
"duckduckgo-search",
"nest_asyncio",
- "waitress",
"werkzeug",
"loguru",
"pillow",
@@ -35,8 +33,6 @@ install_requires = [
"asyncstdlib",
"async-property",
"undetected-chromedriver",
- "asyncstdlib",
- "async_property",
"brotli",
"beautifulsoup4",
]
--
cgit v1.2.3