|
from abc import ABC, ABCMeta, abstractmethod |
|
from contextlib import asynccontextmanager |
|
import logging |
|
from typing import Literal, Optional |
|
from httpx import AsyncClient |
|
from pydantic import BaseModel, Field |
|
from playwright.async_api import Browser, BrowserContext, Page |
|
from asyncio import Semaphore |
|
|
|
|
|
|
|
|
|
class SerpQuery(BaseModel):
    """Model for SERP query"""

    # Free-text search string; ``default=...`` marks the field as required.
    query: str = Field(default=..., description="The query to search for")

    # Requested number of results. The description advertises 10/25/50/100 as
    # valid values, but nothing in this model enforces that restriction.
    n_results: int = Field(
        default=10,
        description="Number of results to return for each query. Valid values are 10, 25, 50 and 100",
    )

    # Result ordering, limited to the two literal choices below.
    sort_by: Literal["relevance", "date"] = Field(
        default="relevance",
        description="How to sort search results.",
    )
|
|
|
|
|
class SerpResultItem(BaseModel):
    """Model for a single SERP result item"""

    # Required: every result carries a title and a link.
    title: str = Field(default=..., description="Title of the search result")
    href: str = Field(default=..., description="URL of the search result")

    # Optional extras; both default to None when a backend does not supply them.
    body: Optional[str] = Field(default=None, description="Snippet of the search result")
    content_slug: Optional[str] = Field(
        default=None,
        description="Content slug of the search result. A slug that encodes the content type and URL that can be used to fetch the full content later",
    )

    class Config:
        # Accept fields that are not declared above instead of rejecting them.
        extra = "allow"
|
|
|
|
|
|
|
|
|
|
|
class SERPBackendBase(ABC):
    """Abstract interface for SERP scraping backends.

    Concrete backends must supply a ``name``, a ``category`` and an async
    ``query`` implementation; the class cannot be instantiated otherwise.
    """

    def __init__(self):
        # The base holds no state; defined so subclasses can chain to it.
        return None

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of the backend, used when building slugs."""
        ...

    @property
    @abstractmethod
    def category(self) -> Literal["general", "patent", "scholar"]:
        """Kind of content this backend returns; consulted by search_auto."""
        ...

    @abstractmethod
    async def query(self, query: SerpQuery, client: AsyncClient) -> list[SerpResultItem]:
        """Execute ``query`` using the HTTP ``client`` and return the parsed results."""
        ...
|
|
|
|
|
class PlaywrightSerpBackendBase(SERPBackendBase):
    """Base class for SERP scraping backends that drive a Playwright browser.

    Subclasses implement :meth:`query_serp_page` rather than :meth:`query`;
    ``query_serp_backend`` routes instances of this class to
    ``query_serp_page`` with a ``Browser``.
    """

    def __init__(self):
        # No Playwright-specific state; chain to the base initialiser.
        super().__init__()

    async def query(self, query: SerpQuery, client: AsyncClient) -> list[SerpResultItem]:
        """Unsupported for Playwright backends.

        Raises:
            NotImplementedError: always; callers must use ``query_serp_page``.
        """
        # Name the method that actually exists on this class so the error is actionable.
        raise NotImplementedError("query_serp_page method must be used instead")

    @abstractmethod
    async def query_serp_page(self, browser: Browser, query: SerpQuery) -> list[SerpResultItem]:
        """Perform a SERP query using the given Playwright ``browser`` and return results."""
        ...
|
|
|
|
|
async def query_serp_backend(backend: SERPBackendBase, query: SerpQuery, client: AsyncClient, browser: Browser) -> list[SerpResultItem]:
    """Dispatch ``query`` to ``backend`` and return its results.

    Playwright-based backends are driven through ``browser`` via
    ``query_serp_page``; every other backend receives the HTTP ``client``
    through ``query``.
    """
    logging.info("Querying %s with %s", backend.name, query)

    if not isinstance(backend, PlaywrightSerpBackendBase):
        return await backend.query(query, client)
    return await backend.query_serp_page(browser, query)
|
|
|
|
|
def get_backends_doc(backends: list[SERPBackendBase]) -> str:
    """Build a Markdown snippet listing every backend's name and category."""
    header = "### Available SERP Backends \n\n\n "
    entries = (f" \n\n `{item.name}` - category: `{item.category}`" for item in backends)
    return header + "".join(entries)
|
|
|
|
|
@asynccontextmanager
async def playwright_open_page(browser: Browser, sema: Semaphore):
    """Open a fresh browser context and page, gated by a semaphore.

    Acquires ``sema``, then creates a new context on ``browser`` and a page
    inside it, and yields the page. On exit the page and context are closed
    and the semaphore is released.

    Cleanup is nested so that a failure while creating the context or page,
    or while closing the page, still closes whatever was already opened and
    always releases the semaphore.

    Args:
        browser: Playwright browser used to create the context.
        sema: Semaphore bounding how many pages are open at once.

    Yields:
        The newly opened ``Page``.
    """
    # ``async with`` guarantees the permit is returned even if setup fails.
    async with sema:
        context: BrowserContext = await browser.new_context()
        try:
            page: Page = await context.new_page()
            try:
                yield page
            finally:
                await page.close()
        finally:
            # Runs even if new_page() or page.close() raised.
            await context.close()
|
|