# SERPent2 / serp/base.py
# Author: Game4all
# Commit: d907837 ("Initial commit")
from abc import ABC, ABCMeta, abstractmethod
from contextlib import asynccontextmanager
import logging
from typing import Literal, Optional
from httpx import AsyncClient
from pydantic import BaseModel, Field
from playwright.async_api import Browser, BrowserContext, Page
from asyncio import Semaphore
# ========================== Schemas ==========================
class SerpQuery(BaseModel):
    """A single search request to be sent to a SERP backend."""

    # Free-text search terms; required.
    query: str = Field(default=..., description="The query to search for")

    # Requested result count per query.
    # NOTE(review): the description lists 10/25/50/100 as valid, but no
    # validator enforces this here — other values are accepted as-is.
    n_results: int = Field(
        default=10,
        description="Number of results to return for each query. Valid values are 10, 25, 50 and 100",
    )

    # Ordering of the returned results.
    sort_by: Literal["relevance", "date"] = Field(
        default="relevance",
        description="How to sort search results.",
    )
class SerpResultItem(BaseModel):
    """One search hit returned by a SERP backend.

    Backends may attach extra, backend-specific fields; they are kept because
    the model is configured with ``extra = "allow"``.
    """

    # Human-readable title of the hit; required.
    title: str = Field(default=..., description="Title of the search result")
    # Target URL of the hit; required.
    href: str = Field(default=..., description="URL of the search result")
    # Optional text snippet shown alongside the hit.
    body: Optional[str] = Field(
        default=None,
        description="Snippet of the search result",
    )
    # Optional opaque slug for fetching the full content later.
    content_slug: Optional[str] = Field(
        default=None,
        description="Content slug of the search result. A slug that encodes the content type and URL that can be used to fetch the full content later",
    )

    class Config:
        # Keep any additional fields a backend puts on a result item.
        extra = "allow"
# =============================== Base classes ===============================
class SERPBackendBase(ABC):
    """Abstract interface every SERP scraping backend implements.

    Concrete backends provide a ``name``, a content ``category`` and an async
    ``query`` method that performs the search over HTTP.
    """

    def __init__(self):
        """Backends share no base-class state."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of the backend, used to tag it in slugs."""
        ...

    @property
    @abstractmethod
    def category(self) -> Literal["general", "patent", "scholar"]:
        """Kind of content this backend returns; consulted by search_auto."""
        ...

    @abstractmethod
    async def query(self, query: SerpQuery, client: AsyncClient) -> list[SerpResultItem]:
        """Run ``query`` using the given HTTP ``client`` and return the hits."""
        ...
class PlaywrightSerpBackendBase(SERPBackendBase):
    """Base class for SERP scraping backends that drive a Playwright browser.

    Subclasses implement :meth:`query_serp_page` rather than :meth:`query`.
    ``query_serp_backend`` checks ``isinstance(backend, PlaywrightSerpBackendBase)``
    to decide which of the two entry points to call.
    """

    def __init__(self):
        """Playwright backends share no base-class state."""
        pass

    async def query(self, query: SerpQuery, client: AsyncClient) -> list[SerpResultItem]:
        """Unsupported for Playwright backends; always raises.

        Raises:
            NotImplementedError: always — use :meth:`query_serp_page` instead.
        """
        # Name the method subclasses actually implement so the error is actionable.
        raise NotImplementedError("query_serp_page method must be used instead")

    @abstractmethod
    async def query_serp_page(self, browser: Browser, query: SerpQuery) -> list[SerpResultItem]:
        """Perform a SERP query using Playwright and return results.

        Args:
            browser: Playwright browser in which the backend opens its pages.
            query: the search request to run.

        Returns:
            The result items scraped for ``query``.
        """
        pass
async def query_serp_backend(backend: SERPBackendBase, query: SerpQuery, client: AsyncClient, browser: Browser) -> list[SerpResultItem]:
    """Run ``query`` against ``backend`` through the transport it expects.

    Playwright backends (instances of ``PlaywrightSerpBackendBase``) are driven
    via ``query_serp_page`` with the shared ``browser``. Every other backend is
    called via ``query`` with the shared HTTP ``client``.

    Args:
        backend: the backend to query.
        query: the search request.
        client: HTTP client handed to non-Playwright backends.
        browser: Playwright browser handed to Playwright backends.

    Returns:
        The result items returned by the backend.
    """
    # Lazy %-style args: str(query) is only rendered if INFO logging is enabled.
    logging.info("Querying %s with %s", backend.name, query)
    if isinstance(backend, PlaywrightSerpBackendBase):
        return await backend.query_serp_page(browser, query)
    return await backend.query(query, client)
def get_backends_doc(backends: list[SERPBackendBase]) -> str:
    """Render a Markdown section listing the given backends.

    Args:
        backends: backends to document, emitted in the order given.

    Returns:
        A Markdown string made of an "Available SERP Backends" heading followed
        by one ``name`` - category entry per backend.
    """
    heading = "### Available SERP Backends \n\n\n "
    lines = (
        f" \n\n `{item.name}` - category: `{item.category}`"
        for item in backends
    )
    return heading + "".join(lines)
@asynccontextmanager
async def playwright_open_page(browser: Browser, sema: Semaphore):
    """Open a Playwright page in a fresh browser context, bounded by ``sema``.

    One slot of ``sema`` is held from before the context is created until after
    it is closed. The semaphore's initial value therefore caps how many pages
    this helper keeps open at once.

    Args:
        browser: Playwright browser used to create the context.
        sema: concurrency limiter shared between callers.

    Yields:
        The newly opened page.

    Cleanup order is page, then context, then the semaphore slot. It runs even
    when context or page creation fails, or when the page close itself raises,
    so a failure can never leak a semaphore slot or an open context.
    """
    async with sema:
        context: BrowserContext = await browser.new_context()
        try:
            page: Page = await context.new_page()
            try:
                yield page
            finally:
                await page.close()
        finally:
            await context.close()