|
from contextlib import asynccontextmanager |
|
from typing import Optional |
|
from fastapi import FastAPI |
|
from pydantic import BaseModel, Field |
|
from playwright.async_api import async_playwright, Browser, BrowserContext, Page |
|
from urllib.parse import quote_plus |
|
import logging |
|
import re |
|
import uvicorn |
|
|
|
logging.basicConfig(level=logging.INFO)

# Shared Playwright state. Both names stay None until api_lifespan starts the
# Playwright driver and launches the browser; the search endpoints read
# pw_browser from here.
playwright = None
pw_browser: Optional[Browser] = None
|
|
|
|
|
@asynccontextmanager
async def api_lifespan(app: FastAPI):
    """Own the shared headless Chromium instance for the app's lifetime.

    Before the application starts serving, this starts Playwright and launches
    one headless Chromium browser. Both are published through the module
    globals ``playwright`` and ``pw_browser``. On shutdown the browser is
    closed and Playwright is stopped. This also happens when the lifespan is
    exited via an exception or when the browser fails to launch.
    """
    global playwright, pw_browser

    playwright = await async_playwright().start()
    try:
        pw_browser = await playwright.chromium.launch(headless=True)
        try:
            yield
        finally:
            await pw_browser.close()
    finally:
        # Stop the driver even if launch() raised, so it is not left running.
        await playwright.stop()
|
|
|
# The ASGI application; api_lifespan brackets serving with browser start/stop.
app: FastAPI = FastAPI(
    lifespan=api_lifespan,
)
|
|
|
|
|
class APISearchParams(BaseModel):
    # Required field: the endpoints run one search per entry, in order.
    queries: list[str] = Field(
        default=...,
        description="The list of queries to search for",
    )
    # Per-query result count, default 10.
    # NOTE(review): the allowed values listed in the description are not
    # enforced by any validator here.
    n_results: int = Field(
        default=10,
        description="Number of results to return for each query. Valid values are 10, 25, 50 and 100",
    )
|
|
|
|
|
class APIPatentResults(BaseModel):
    """Response of /search_patents endpoint"""

    # Error message for the request, or None when no error is reported.
    error: Optional[str]

    # One dict per patent found; the /search_patents handler builds each with
    # "href" (Google Patents URL) and "id" (patent ID string) keys.
    results: Optional[list[dict]]
|
|
|
|
|
class APIBraveResults(BaseModel):
    """Response of /search_brave endpoint"""

    # Error message for the request, or None when no error is reported.
    error: Optional[str]

    # One dict per search hit, as built by query_brave_search, with "title",
    # "body" (snippet text) and "href" keys; results of all queries are
    # concatenated.
    results: Optional[list[dict]]
|
|
|
|
|
# Patent-ID shape searched for in each result item's text: two uppercase
# letters, at least six digits, then an optional letter with an optional
# digit (e.g. "US1234567B2").
_PATENT_ID_PATTERN = re.compile(r"\b[A-Z]{2}\d{6,}(?:[A-Z]\d?)?\b")


async def query_google_patents(browser: Browser, q: str, n_results: int = 10):
    """Query Google Patents and return the patent IDs listed on the results page.

    A fresh browser context is opened for the request, with stylesheets and
    images blocked. The search page for ``q`` is loaded, and the function
    waits until at least ``n_results`` ``search-result-item`` elements are
    rendered. For each item, the first patent-ID-shaped token in its span
    text is collected.

    Args:
        browser: Shared Playwright browser used to open an isolated context.
        q: Free-text search query.
        n_results: Sent as the ``num`` URL parameter and used as the minimum
            number of result items to wait for.

    Returns:
        Patent ID strings in page order; items with no matching token are skipped.

    Raises:
        Playwright's TimeoutError if fewer than ``n_results`` items render
        within 30 seconds, plus any navigation error from Playwright. The
        browser context is closed in every case.
    """
    context: BrowserContext = await browser.new_context()
    try:
        page: Page = await context.new_page()

        async def _block_resources(route, request):
            # Only DOM text is needed, so skip stylesheet and image downloads.
            if request.resource_type in ("stylesheet", "image"):
                await route.abort()
            else:
                await route.continue_()

        await page.route("**/*", _block_resources)

        encoded_q = quote_plus(q)
        url = f"https://patents.google.com/?q=({encoded_q})&oq={encoded_q}&num={n_results}"
        await page.goto(url)

        # n_results is passed as a JS argument rather than spliced into the
        # expression source.
        await page.wait_for_function(
            "n => document.querySelectorAll('search-result-item').length >= n",
            arg=n_results,
            timeout=30_000,
        )

        matches = []
        for item in await page.locator("search-result-item").all():
            all_text = " ".join(await item.locator("span").all_inner_texts())
            found = _PATENT_ID_PATTERN.search(all_text)
            if found:
                matches.append(found.group(0))
        return matches
    finally:
        # Release the context (and its page) even when loading or waiting fails.
        await context.close()
|
|
|
|
|
async def query_brave_search(browser: Browser, q: str, n_results: int = 10):
    """Query Brave Search and return the organic results shown on the first page.

    A fresh browser context is opened for the request, with stylesheets and
    images blocked, and the Brave results page for ``q`` is loaded. Each
    ``.snippet`` card becomes one result dict with "title", "body" and "href"
    keys. A card is skipped when it has no link, when its link has no href,
    or when the href is site-relative (starts with "/").

    Args:
        browser: Shared Playwright browser used to open an isolated context.
        q: Free-text search query.
        n_results: Maximum number of results to return (applied as a slice).

    Returns:
        List of result dicts, in page order, truncated to ``n_results``.

    The browser context is always closed, including when navigation raises.
    """
    context: BrowserContext = await browser.new_context()
    try:
        page: Page = await context.new_page()

        async def _block_resources(route, request):
            # Only DOM text is needed, so skip stylesheet and image downloads.
            if request.resource_type in ("stylesheet", "image"):
                await route.abort()
            else:
                await route.continue_()

        await page.route("**/*", _block_resources)
        await page.goto(f"https://search.brave.com/search?q={quote_plus(q)}")

        results = []
        for card in await page.locator('.snippet').all():
            titles = await card.locator('.title').all_inner_texts()
            descriptions = await card.locator('.snippet-description').all_inner_texts()

            anchor = card.locator('a').first
            # count() returns immediately. get_attribute() on a missing element
            # would instead wait for the default locator timeout.
            if await anchor.count() == 0:
                continue
            href = await anchor.get_attribute('href')
            # Skip anchors without an href and site-relative links
            # (presumably Brave-internal pages).
            if not href or href.startswith('/'):
                continue

            results.append({
                "title": titles[0] if titles else "",
                "body": descriptions[0] if descriptions else "",
                "href": href,
            })

        return results[:n_results]
    finally:
        await context.close()
|
|
|
|
|
@app.post("/search_scholar")
async def query_google_scholar(params: APISearchParams):
    """Google Scholar search endpoint, not implemented yet.

    ``params`` is accepted but ignored; the response is always an error payload.
    """
    return dict(error="Unimplemented")
|
|
|
|
|
@app.get('/')
async def status():
    """Health check: report that the service process is up."""
    return dict(status="running")
|
|
|
|
|
@app.post("/search_patents")
async def search_patents(params: APISearchParams) -> APIPatentResults:
    """Search Google Patents for every query and return the found documents.

    Queries run one after another through ``query_google_patents`` on the
    shared browser. A query that raises is logged with its traceback and
    skipped, so the response contains the hits of the queries that
    succeeded. Patent IDs are returned in the order found, including repeats
    across queries. Each ID is paired with its English Google Patents URL.

    ``error`` is set only when at least one query was given and all of them
    failed; otherwise it is None.
    """
    patent_ids: list[str] = []
    failed_count = 0
    for q in params.queries:
        logging.info("Searching Google Patents with query `%s`", q)
        try:
            patent_ids.extend(await query_google_patents(pw_browser, q, params.n_results))
        except Exception as e:
            # Best effort: one failing query must not abort the whole request.
            logging.exception("Failed to query Google Patents with query `%s`: %s", q, e)
            failed_count += 1

    error = None
    if params.queries and failed_count == len(params.queries):
        # Distinguish "everything failed" from "no patents found".
        error = f"All {failed_count} Google Patents queries failed"

    return APIPatentResults(
        results=[
            {"href": f"https://patents.google.com/patent/{patent_id}/en", "id": patent_id}
            for patent_id in patent_ids
        ],
        error=error,
    )
|
|
|
|
|
@app.post("/search_brave")
async def search_brave(params: APISearchParams) -> APIBraveResults:
    """Search Brave for every query and return the concatenated results.

    Queries run one after another through ``query_brave_search`` on the
    shared browser, each limited to ``params.n_results`` hits. A query that
    raises is logged with its traceback and skipped.

    ``error`` is set only when at least one query was given and all of them
    failed; otherwise it is None.
    """
    results = []
    failed_count = 0
    for q in params.queries:
        logging.info("Searching Brave search with query `%s`", q)
        try:
            results.extend(await query_brave_search(pw_browser, q, params.n_results))
        except Exception as e:
            # Best effort: one failing query must not abort the whole request.
            logging.exception("Failed to query Brave search with query `%s`: %s", q, e)
            failed_count += 1

    error = None
    if params.queries and failed_count == len(params.queries):
        # Distinguish "everything failed" from "no results found".
        error = f"All {failed_count} Brave search queries failed"

    return APIBraveResults(results=results, error=error)
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces, port 7860. The guard keeps a plain import of
    # this module (e.g. `uvicorn module:app` or tests) from blocking on
    # server startup.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|