Spaces:
Sleeping
Sleeping
import os | |
import json | |
from g4f.providers.response import Reasoning, JsonConversation, FinishReason | |
from g4f.typing import AsyncResult, Messages | |
import json | |
import re | |
import time | |
from urllib.parse import quote_plus | |
from fastapi import FastAPI, Response, Request | |
from fastapi.responses import RedirectResponse | |
from g4f.image import images_dir, copy_images | |
import g4f.api | |
import g4f.Provider | |
from g4f.Provider.base_provider import AsyncGeneratorProvider, ProviderModelMixin | |
from g4f.typing import AsyncResult, Messages | |
from g4f.requests import StreamSession | |
from g4f.providers.response import ProviderInfo, JsonConversation, PreviewResponse, SynthesizeData, TitleGeneration, RequestLogin | |
from g4f.providers.response import Parameters, FinishReason, Usage, Reasoning | |
from g4f.errors import ModelNotSupportedError | |
from g4f import debug | |
class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider that forwards chat requests to a remote g4f backend API.

    Requests are POSTed to ``{url}/backend-api/v2/conversation``. Each line of
    the streamed response is a JSON object with a ``type`` field, translated
    back into the matching g4f response object.
    """

    url = "https://ahe.hopto.org"
    working = True
    # Passed as ``ssl=`` to the POST below; False disables TLS verification.
    ssl = False

    models = [
        *g4f.Provider.OpenaiAccount.get_models(),
        *g4f.Provider.HuggingChat.get_models(),
        "flux",
        "flux-pro",
        "MiniMax-01",
        "Microsoft Copilot",
    ]

    @classmethod
    def get_model(cls, model):
        """Normalize a user-facing model name to the name the backend expects.

        Raises:
            ModelNotSupportedError: if ``model`` matches no known family and is
                not offered by OpenaiAccount or HuggingChat.
        """
        if "MiniMax" in model:
            model = "MiniMax"
        elif "Copilot" in model:
            model = "Copilot"
        elif "FLUX" in model:
            # e.g. "FLUX.1-dev" -> "flux-dev"
            model = f"flux-{model.split('-')[-1]}"
        elif "flux" in model:
            # Keep only the last space-separated word.
            model = model.split(' ')[-1]
        elif (model not in g4f.Provider.OpenaiAccount.get_models()
              and model not in g4f.Provider.HuggingChat.get_models()):
            raise ModelNotSupportedError(f"Model: {model}")
        return model

    @classmethod
    def get_provider(cls, model):
        """Return the backend provider name for a normalized model, or None."""
        if model.startswith("MiniMax"):
            return "HailuoAI"
        elif model == "Copilot":
            return "CopilotAccount"
        elif model in g4f.Provider.OpenaiAccount.get_models():
            return "OpenaiAccount"
        elif model in g4f.Provider.HuggingChat.get_models():
            return "HuggingChat"
        return None

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_key: str = None,
        proxy: str = None,
        timeout: int = 0,
        **kwargs
    ) -> AsyncResult:
        """Stream a completion from the remote backend.

        Args:
            model: User-facing model name; normalized via ``get_model``.
            messages: Chat history, sent unchanged in the request body.
            api_key: Only whether it is set is logged. It is not sent in the
                request body.
            proxy: Proxy URL handed to ``StreamSession``.
            timeout: Timeout handed to ``StreamSession``.
            **kwargs: Extra fields merged into the JSON request body.

        Yields:
            ProviderInfo, JsonConversation, PreviewResponse, content ``str``,
            SynthesizeData, Parameters, Usage, Reasoning, TitleGeneration,
            FinishReason, and ``Exception`` objects for "message" events.

        Raises:
            ModelNotSupportedError: from ``get_model``, before any request is made.
        """
        # Never write the key itself to the log.
        debug.log(f"{__name__}: api_key={'<set>' if api_key else None}")
        model = cls.get_model(model)
        provider = cls.get_provider(model)

        def on_image(match):
            # Rewrite "[![alt](src)](href)". The thumbnail goes through this
            # app's /download/{filename} route with the remote URL in ``url=``;
            # the outer link points at the local /images/{filename} copy.
            alt, href = match.group(1), match.group(3)
            extension = href.split(".")[-1].split("?")[0]
            extension = "" if not extension or len(extension) > 4 else f".{extension}"
            filename = f"{int(time.time())}_{quote_plus(alt[:100], '')}{extension}"
            download_url = f"/download/{filename}?url={cls.url}{href}"
            return f"[![{alt}]({download_url})](/images/{filename})"

        async with StreamSession(
            proxy=proxy,
            headers={"Accept": "text/event-stream"},
            timeout=timeout
        ) as session:
            async with session.post(f"{cls.url}/backend-api/v2/conversation", json={
                "model": model,
                "messages": messages,
                "provider": provider,
                **kwargs
            }, ssl=cls.ssl) as response:
                async for line in response.iter_lines():
                    if not line:
                        # Blank lines carry no JSON; json.loads would raise on them.
                        continue
                    data = json.loads(line)
                    data_type = data.pop("type")
                    if data_type == "provider":
                        yield ProviderInfo(**data[data_type])
                        # Later "conversation" events are keyed by this name.
                        provider = data[data_type]["name"]
                    elif data_type == "conversation":
                        conversations = data[data_type]
                        # Prefer the active provider's state; fall back to the "" entry.
                        yield JsonConversation(**(
                            conversations[provider] if provider in conversations else conversations[""]
                        ))
                    elif data_type == "conversation_id":
                        pass
                    elif data_type == "message":
                        # Yielded (not raised) as an Exception object wrapping the event.
                        yield Exception(data)
                    elif data_type == "preview":
                        yield PreviewResponse(data[data_type])
                    elif data_type == "content":
                        yield re.sub(r'\[\!\[(.+?)\]\(([^)]+?)\)\]\(([^)]+?)\)', on_image, data["content"])
                    elif data_type == "synthesize":
                        yield SynthesizeData(**data[data_type])
                    elif data_type == "parameters":
                        yield Parameters(**data[data_type])
                    elif data_type == "usage":
                        yield Usage(**data[data_type])
                    elif data_type == "reasoning":
                        # Reasoning fields sit at the top level of the event, not under "reasoning".
                        yield Reasoning(**data)
                    elif data_type == "login":
                        pass
                    elif data_type == "title":
                        yield TitleGeneration(data[data_type])
                    elif data_type == "finish":
                        yield FinishReason(data[data_type]["reason"])
                    elif data_type == "log":
                        debug.log(data[data_type])
                    else:
                        debug.log(f"Unknown data: ({data_type}) {data}")
# Register BackendApi in g4f's provider map under the key "Feature".
# NOTE(review): presumably this lets requests select it as provider "Feature"; confirm.
g4f.Provider.__map__["Feature"] = BackendApi
def create_app():
    """Build the FastAPI app serving the g4f API, an image download route and the demo GUI.

    Returns:
        The configured ``FastAPI`` instance.
    """
    g4f.debug.logging = True
    g4f.api.AppConfig.gui = True
    g4f.api.AppConfig.demo = True

    app = FastAPI()

    # Add CORS middleware
    # NOTE(review): any origin is allowed *with* credentials; confirm that is intended.
    app.add_middleware(
        g4f.api.CORSMiddleware,
        allow_origin_regex=".*",
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    api = g4f.api.Api(app)
    api.register_routes()
    api.register_authorization()
    api.register_validation_exception_handler()

    # Must be registered before the catch-all "/" GUI mount below, or it is unreachable.
    @app.get("/download/{filename}", response_class=RedirectResponse)
    async def download(filename, request: Request):
        """Fetch a remote image into images_dir once, then redirect to /images/{filename}.

        The source URL is everything after ``url=`` in the raw query string, so
        a query string inside the source URL survives. Only URLs under
        ``BackendApi.url`` are fetched, so the endpoint cannot be used as an open
        proxy (SSRF). Returns 404 if the file is still missing afterwards.
        """
        filename = os.path.basename(filename)  # drop any path components
        target = os.path.join(images_dir, filename)
        if not os.path.exists(target):
            # partition instead of split(...)[1]: a missing "url=" must not raise IndexError.
            _, found, url = str(request.query_params).partition("url=")
            if found and url:
                source_url = (
                    url.replace("%2F", "/")
                    .replace("%3A", ":")
                    .replace("%3F", "?")
                    .replace("%3D", "=")
                )
                # The trailing "/" stops look-alike hosts such as "<url>.evil.com".
                if source_url.startswith(f"{BackendApi.url}/"):
                    # ssl=False: TLS verification is disabled for this fetch.
                    await copy_images([source_url], target=target, ssl=False)
        if not os.path.exists(target):
            return Response(status_code=404)
        return RedirectResponse(f"/images/{filename}")

    gui_app = g4f.api.WSGIMiddleware(g4f.api.get_gui_app(g4f.api.AppConfig.demo))
    app.mount("/", gui_app)
    return app
app = create_app() |