# pagination_detector.py
import os
import json
import logging
from typing import List, Dict, Tuple, Union

import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
import google.generativeai as genai
from groq import Groq
from pydantic import BaseModel, Field

from assets import PROMPT_PAGINATION, PRICING, LLAMA_MODEL_FULLNAME, GROQ_LLAMA_MODEL_FULLNAME

load_dotenv()
class PaginationData(BaseModel):
    page_urls: List[str] = Field(
        default_factory=list,
        description="List of pagination URLs, including the 'Next' button URL if present",
    )
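
# The structured-output calls below are expected to return JSON matching
# PaginationData. A minimal illustrative payload (the URLs are hypothetical):
#
#   {"page_urls": ["https://example.com/items?page=2",
#                  "https://example.com/items?page=3"]}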
def calculate_pagination_price(token_counts: Dict[str, int], model: str) -> float:
    """
    Calculate the price for pagination based on token counts and the selected model.

    Args:
        token_counts (Dict[str, int]): A dictionary containing 'input_tokens' and 'output_tokens'.
        model (str): The name of the selected model.

    Returns:
        float: The total price for the pagination operation.
    """
    input_tokens = token_counts['input_tokens']
    output_tokens = token_counts['output_tokens']

    input_price = input_tokens * PRICING[model]['input']
    output_price = output_tokens * PRICING[model]['output']

    return input_price + output_price
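
# PRICING is imported from assets. From the lookups above it is assumed to map
# model names to per-token USD rates, e.g. (hypothetical values):
#
#   PRICING = {"gpt-4o-mini": {"input": 0.15e-6, "output": 0.60e-6}, ...}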
def detect_pagination_elements(
    url: str, indications: str, selected_model: str, markdown_content: str
) -> Tuple[Union[PaginationData, Dict, str], Dict, float]:
    """
    Uses AI models to analyze markdown content and extract pagination elements.

    Args:
        url (str): The URL of the page being analyzed.
        indications (str): Optional user indications to guide the extraction.
        selected_model (str): The name of the model to use.
        markdown_content (str): The markdown content to analyze.

    Returns:
        Tuple[PaginationData, Dict, float]: Parsed pagination data, token counts, and pagination price.
    """
    try:
        prompt_pagination = (
            PROMPT_PAGINATION
            + "\nThe url of the page to extract pagination from: " + url
            + "\nIf the urls that you find are not complete, combine them intelligently in a way that fits the pattern. **ALWAYS GIVE A FULL URL**"
        )

        if indications != "":
            prompt_pagination += (
                "\n\nThese are the user's indications; pay special attention to them: "
                + indications
                + "\n\nBelow are the markdowns of the website:\n\n"
            )
        else:
            prompt_pagination += (
                "\nThere are no user indications in this case, just apply the logic described."
                + "\n\nBelow are the markdowns of the website:\n\n"
            )
        if selected_model in ["gpt-4o-mini", "gpt-4o-2024-08-06"]:
            # Use the OpenAI API with structured output parsing
            client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
            completion = client.beta.chat.completions.parse(
                model=selected_model,
                messages=[
                    {"role": "system", "content": prompt_pagination},
                    {"role": "user", "content": markdown_content},
                ],
                response_format=PaginationData,
            )

            # Extract the parsed response
            parsed_response = completion.choices[0].message.parsed

            # Calculate token counts using tiktoken
            encoder = tiktoken.encoding_for_model(selected_model)
            input_token_count = len(encoder.encode(markdown_content))
            output_token_count = len(encoder.encode(json.dumps(parsed_response.dict())))
            token_counts = {
                "input_tokens": input_token_count,
                "output_tokens": output_token_count,
            }

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return parsed_response, token_counts, pagination_price
elif selected_model == "gemini-1.5-flash":
# Use Google Gemini API
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
model = genai.GenerativeModel(
'gemini-1.5-flash',
generation_config={
"response_mime_type": "application/json",
"response_schema": PaginationData
}
)
prompt = f"{prompt_pagination}\n{markdown_content}"
# Count input tokens using Gemini's method
input_tokens = model.count_tokens(prompt)
completion = model.generate_content(prompt)
# Extract token counts from usage_metadata
usage_metadata = completion.usage_metadata
token_counts = {
"input_tokens": usage_metadata.prompt_token_count,
"output_tokens": usage_metadata.candidates_token_count
}
# Get the result
response_content = completion.text
# Log the response content and its type
logging.info(f"Gemini Flash response type: {type(response_content)}")
logging.info(f"Gemini Flash response content: {response_content}")
# Try to parse the response as JSON
try:
parsed_data = json.loads(response_content)
if isinstance(parsed_data, dict) and 'page_urls' in parsed_data:
pagination_data = PaginationData(**parsed_data)
else:
pagination_data = PaginationData(page_urls=[])
except json.JSONDecodeError:
logging.error("Failed to parse Gemini Flash response as JSON")
pagination_data = PaginationData(page_urls=[])
# Calculate the price
pagination_price = calculate_pagination_price(token_counts, selected_model)
return pagination_data, token_counts, pagination_price
elif selected_model == "Llama3.1 8B":
# Use Llama model via OpenAI API pointing to local server
openai.api_key = "lm-studio"
openai.api_base = "http://localhost:1234/v1"
response = openai.ChatCompletion.create(
model=LLAMA_MODEL_FULLNAME,
messages=[
{"role": "system", "content": prompt_pagination},
{"role": "user", "content": markdown_content},
],
temperature=0.7,
)
response_content = response['choices'][0]['message']['content'].strip()
# Try to parse the JSON
try:
pagination_data = json.loads(response_content)
except json.JSONDecodeError:
pagination_data = {"next_buttons": [], "page_urls": []}
# Token counts
token_counts = {
"input_tokens": response['usage']['prompt_tokens'],
"output_tokens": response['usage']['completion_tokens']
}
# Calculate the price
pagination_price = calculate_pagination_price(token_counts, selected_model)
return pagination_data, token_counts, pagination_price
elif selected_model == "Groq Llama3.1 70b":
# Use Groq client
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
response = client.chat.completions.create(
model=GROQ_LLAMA_MODEL_FULLNAME,
messages=[
{"role": "system", "content": prompt_pagination},
{"role": "user", "content": markdown_content},
],
)
response_content = response.choices[0].message.content.strip()
# Try to parse the JSON
try:
pagination_data = json.loads(response_content)
except json.JSONDecodeError:
pagination_data = {"page_urls": []}
# Token counts
token_counts = {
"input_tokens": response.usage.prompt_tokens,
"output_tokens": response.usage.completion_tokens
}
# Calculate the price
pagination_price = calculate_pagination_price(token_counts, selected_model)
# Ensure the pagination_data is a dictionary
if isinstance(pagination_data, PaginationData):
pagination_data = pagination_data.dict()
elif not isinstance(pagination_data, dict):
pagination_data = {"page_urls": []}
return pagination_data, token_counts, pagination_price
        else:
            raise ValueError(f"Unsupported model: {selected_model}")

    except Exception as e:
        logging.error(f"An error occurred in detect_pagination_elements: {e}")
        # Return default values if an error occurs
        return PaginationData(page_urls=[]), {"input_tokens": 0, "output_tokens": 0}, 0.0
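
# A minimal usage sketch (not part of the original module). It assumes a valid
# OPENAI_API_KEY in .env; the URL and markdown below are hypothetical
# placeholders standing in for real scraped content.
if __name__ == "__main__":
    sample_markdown = (
        "[1](https://example.com/items?page=1) "
        "[2](https://example.com/items?page=2) "
        "[Next](https://example.com/items?page=2)"
    )
    data, tokens, price = detect_pagination_elements(
        url="https://example.com/items",
        indications="",
        selected_model="gpt-4o-mini",
        markdown_content=sample_markdown,
    )
    # The OpenAI branch returns a PaginationData instance; other branches may return a dict
    urls = data.page_urls if isinstance(data, PaginationData) else data.get("page_urls", [])
    print(f"Pagination URLs: {urls}")
    print(f"Token counts: {tokens}, price: ${price:.6f}")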