# pagination_detector.py
import os
import json
import logging
from typing import List, Dict, Tuple, Union

from pydantic import BaseModel, Field, ValidationError
import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
import google.generativeai as genai
from groq import Groq

from assets import PROMPT_PAGINATION, PRICING, LLAMA_MODEL_FULLNAME, GROQ_LLAMA_MODEL_FULLNAME

load_dotenv()

class PaginationData(BaseModel):
    page_urls: List[str] = Field(
        default_factory=list,
        description="List of pagination URLs, including 'Next' button URL if present",
    )
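
# Illustrative shape of a parsed result (hypothetical URLs, not real data):
# PaginationData(page_urls=["https://example.com/products?page=2",
#                           "https://example.com/products?page=3"])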

def calculate_pagination_price(token_counts: Dict[str, int], model: str) -> float:
    """
    Calculate the price for pagination based on token counts and the selected model.

    Args:
        token_counts (Dict[str, int]): A dictionary containing 'input_tokens' and 'output_tokens'.
        model (str): The name of the selected model.

    Returns:
        float: The total price for the pagination operation.
    """
    input_tokens = token_counts['input_tokens']
    output_tokens = token_counts['output_tokens']

    input_price = input_tokens * PRICING[model]['input']
    output_price = output_tokens * PRICING[model]['output']

    return input_price + output_price
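
# Worked example (hypothetical per-token rates; the real ones come from PRICING
# in assets.py): with PRICING[model] = {"input": 1.5e-7, "output": 6.0e-7} USD,
# 1,000 input tokens and 200 output tokens cost
# 1000 * 1.5e-7 + 200 * 6.0e-7 = 0.00027 USD.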

def detect_pagination_elements(
    url: str, indications: str, selected_model: str, markdown_content: str
) -> Tuple[Union[PaginationData, Dict, str], Dict, float]:
    """
    Uses AI models to analyze markdown content and extract pagination elements.

    Args:
        url (str): The URL of the page the markdown was scraped from.
        indications (str): Optional user indications to guide the extraction (may be empty).
        selected_model (str): The name of the model to use.
        markdown_content (str): The markdown content to analyze.

    Returns:
        Tuple[PaginationData, Dict, float]: Parsed pagination data, token counts, and pagination price.
    """
    try:
        prompt_pagination = (
            PROMPT_PAGINATION
            + "\nThe URL of the page to extract pagination from: " + url
            + "\nIf the URLs you find are not complete, combine them intelligently so they fit the site's pattern. **ALWAYS GIVE A FULL URL**"
        )
        if indications != "":
            prompt_pagination += (
                "\n\nThese are the user's indications; pay special attention to them: "
                + indications
                + "\n\nBelow is the markdown of the website:\n\n"
            )
        else:
            prompt_pagination += (
                "\nThere are no user indications in this case; just apply the logic described."
                + "\n\nBelow is the markdown of the website:\n\n"
            )

        if selected_model in ["gpt-4o-mini", "gpt-4o-2024-08-06"]:
            # Use OpenAI API with structured output parsing
            client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
            completion = client.beta.chat.completions.parse(
                model=selected_model,
                messages=[
                    {"role": "system", "content": prompt_pagination},
                    {"role": "user", "content": markdown_content},
                ],
                response_format=PaginationData
            )

            # Extract the parsed response
            parsed_response = completion.choices[0].message.parsed

            # Calculate tokens using tiktoken
            encoder = tiktoken.encoding_for_model(selected_model)
            input_token_count = len(encoder.encode(markdown_content))
            output_token_count = len(encoder.encode(json.dumps(parsed_response.dict())))
            token_counts = {
                "input_tokens": input_token_count,
                "output_tokens": output_token_count
            }

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return parsed_response, token_counts, pagination_price
elif selected_model == "gemini-1.5-flash": | |
# Use Google Gemini API | |
genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) | |
model = genai.GenerativeModel( | |
'gemini-1.5-flash', | |
generation_config={ | |
"response_mime_type": "application/json", | |
"response_schema": PaginationData | |
} | |
) | |
prompt = f"{prompt_pagination}\n{markdown_content}" | |
# Count input tokens using Gemini's method | |
input_tokens = model.count_tokens(prompt) | |
completion = model.generate_content(prompt) | |
# Extract token counts from usage_metadata | |
usage_metadata = completion.usage_metadata | |
token_counts = { | |
"input_tokens": usage_metadata.prompt_token_count, | |
"output_tokens": usage_metadata.candidates_token_count | |
} | |
# Get the result | |
response_content = completion.text | |
# Log the response content and its type | |
logging.info(f"Gemini Flash response type: {type(response_content)}") | |
logging.info(f"Gemini Flash response content: {response_content}") | |
# Try to parse the response as JSON | |
try: | |
parsed_data = json.loads(response_content) | |
if isinstance(parsed_data, dict) and 'page_urls' in parsed_data: | |
pagination_data = PaginationData(**parsed_data) | |
else: | |
pagination_data = PaginationData(page_urls=[]) | |
except json.JSONDecodeError: | |
logging.error("Failed to parse Gemini Flash response as JSON") | |
pagination_data = PaginationData(page_urls=[]) | |
# Calculate the price | |
pagination_price = calculate_pagination_price(token_counts, selected_model) | |
return pagination_data, token_counts, pagination_price | |
elif selected_model == "Llama3.1 8B": | |
# Use Llama model via OpenAI API pointing to local server | |
openai.api_key = "lm-studio" | |
openai.api_base = "http://localhost:1234/v1" | |
response = openai.ChatCompletion.create( | |
model=LLAMA_MODEL_FULLNAME, | |
messages=[ | |
{"role": "system", "content": prompt_pagination}, | |
{"role": "user", "content": markdown_content}, | |
], | |
temperature=0.7, | |
) | |
response_content = response['choices'][0]['message']['content'].strip() | |
# Try to parse the JSON | |
try: | |
pagination_data = json.loads(response_content) | |
except json.JSONDecodeError: | |
pagination_data = {"next_buttons": [], "page_urls": []} | |
# Token counts | |
token_counts = { | |
"input_tokens": response['usage']['prompt_tokens'], | |
"output_tokens": response['usage']['completion_tokens'] | |
} | |
# Calculate the price | |
pagination_price = calculate_pagination_price(token_counts, selected_model) | |
return pagination_data, token_counts, pagination_price | |
elif selected_model == "Groq Llama3.1 70b": | |
# Use Groq client | |
client = Groq(api_key=os.environ.get("GROQ_API_KEY")) | |
response = client.chat.completions.create( | |
model=GROQ_LLAMA_MODEL_FULLNAME, | |
messages=[ | |
{"role": "system", "content": prompt_pagination}, | |
{"role": "user", "content": markdown_content}, | |
], | |
) | |
response_content = response.choices[0].message.content.strip() | |
# Try to parse the JSON | |
try: | |
pagination_data = json.loads(response_content) | |
except json.JSONDecodeError: | |
pagination_data = {"page_urls": []} | |
# Token counts | |
token_counts = { | |
"input_tokens": response.usage.prompt_tokens, | |
"output_tokens": response.usage.completion_tokens | |
} | |
# Calculate the price | |
pagination_price = calculate_pagination_price(token_counts, selected_model) | |
# Ensure the pagination_data is a dictionary | |
if isinstance(pagination_data, PaginationData): | |
pagination_data = pagination_data.dict() | |
elif not isinstance(pagination_data, dict): | |
pagination_data = {"page_urls": []} | |
return pagination_data, token_counts, pagination_price | |
else: | |
raise ValueError(f"Unsupported model: {selected_model}") | |

    except Exception as e:
        logging.error(f"An error occurred in detect_pagination_elements: {e}")
        # Return default values if an error occurs
        return PaginationData(page_urls=[]), {"input_tokens": 0, "output_tokens": 0}, 0.0
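
# Minimal usage sketch (assumptions: OPENAI_API_KEY is set in the environment,
# and assets.py provides PROMPT_PAGINATION and PRICING entries for "gpt-4o-mini";
# the URL and markdown below are illustrative, not real data):
if __name__ == "__main__":
    sample_markdown = (
        "[1](https://example.com/products?page=1) "
        "[2](https://example.com/products?page=2) "
        "[Next](https://example.com/products?page=2)"
    )
    data, tokens, price = detect_pagination_elements(
        url="https://example.com/products",
        indications="",
        selected_model="gpt-4o-mini",
        markdown_content=sample_markdown,
    )
    print(data, tokens, f"${price:.6f}")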