# pagination_detector.py
import os
import json
import logging
from typing import List, Dict, Tuple, Union

import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
import google.generativeai as genai
from groq import Groq
from pydantic import BaseModel, Field

from assets import PROMPT_PAGINATION, PRICING, LLAMA_MODEL_FULLNAME, GROQ_LLAMA_MODEL_FULLNAME

load_dotenv()


class PaginationData(BaseModel):
    page_urls: List[str] = Field(
        default_factory=list,
        description="List of pagination URLs, including 'Next' button URL if present",
    )
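
# The schema above corresponds to JSON shaped like the sketch below
# (illustrative only; real model output contains whatever URLs are found
# on the page):
#
#   {"page_urls": ["https://example.com/products?page=2",
#                  "https://example.com/products?page=3"]}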

def calculate_pagination_price(token_counts: Dict[str, int], model: str) -> float:
    """
    Calculate the price for pagination based on token counts and the selected model.

    Args:
        token_counts (Dict[str, int]): A dictionary containing 'input_tokens' and 'output_tokens'.
        model (str): The name of the selected model.

    Returns:
        float: The total price for the pagination operation.
    """
    input_tokens = token_counts['input_tokens']
    output_tokens = token_counts['output_tokens']

    input_price = input_tokens * PRICING[model]['input']
    output_price = output_tokens * PRICING[model]['output']

    return input_price + output_price
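
# PRICING is imported from assets.py; the function above assumes a per-token
# price table shaped roughly like the sketch below. The model names mirror the
# branches in detect_pagination_elements, but the numbers are purely
# illustrative, not real rates:
#
#   PRICING = {
#       "gpt-4o-mini": {"input": 0.15 / 1_000_000, "output": 0.60 / 1_000_000},
#       "gemini-1.5-flash": {"input": 0.075 / 1_000_000, "output": 0.30 / 1_000_000},
#   }
#
# With that layout, calculate_pagination_price({"input_tokens": 1000,
# "output_tokens": 200}, "gpt-4o-mini") returns the summed dollar cost.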

def detect_pagination_elements(url: str, indications: str, selected_model: str, markdown_content: str) -> Tuple[Union[PaginationData, Dict, str], Dict, float]:
    """
    Uses AI models to analyze markdown content and extract pagination elements.

    Args:
        url (str): The URL of the page the markdown was scraped from.
        indications (str): Optional user indications to guide the detection.
        selected_model (str): The name of the model to use.
        markdown_content (str): The markdown content to analyze.

    Returns:
        Tuple[Union[PaginationData, Dict, str], Dict, float]: Parsed pagination data, token counts, and pagination price.
    """
    try:
        prompt_pagination = (
            PROMPT_PAGINATION
            + "\nThe URL of the page to extract pagination from: " + url
            + "\nIf the URLs you find are not complete, combine them intelligently so they fit the pattern. **ALWAYS GIVE A FULL URL**"
        )

        if indications != "":
            prompt_pagination += "\n\nThese are the user's indications; pay special attention to them: " + indications + "\n\nBelow is the markdown of the website:\n\n"
        else:
            prompt_pagination += "\nThere are no user indications in this case; just apply the logic described.\n\nBelow is the markdown of the website:\n\n"
        if selected_model in ["gpt-4o-mini", "gpt-4o-2024-08-06"]:
            # Use OpenAI API with structured output parsing
            client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
            completion = client.beta.chat.completions.parse(
                model=selected_model,
                messages=[
                    {"role": "system", "content": prompt_pagination},
                    {"role": "user", "content": markdown_content},
                ],
                response_format=PaginationData
            )

            # Extract the parsed response
            parsed_response = completion.choices[0].message.parsed

            # Calculate tokens using tiktoken
            encoder = tiktoken.encoding_for_model(selected_model)
            input_token_count = len(encoder.encode(markdown_content))
            output_token_count = len(encoder.encode(json.dumps(parsed_response.dict())))
            token_counts = {
                "input_tokens": input_token_count,
                "output_tokens": output_token_count
            }

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return parsed_response, token_counts, pagination_price
        elif selected_model == "gemini-1.5-flash":
            # Use Google Gemini API with a JSON response constrained to the schema
            genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
            model = genai.GenerativeModel(
                'gemini-1.5-flash',
                generation_config={
                    "response_mime_type": "application/json",
                    "response_schema": PaginationData
                }
            )
            prompt = f"{prompt_pagination}\n{markdown_content}"

            completion = model.generate_content(prompt)

            # Extract token counts from usage_metadata (the API reports usage
            # itself, so no separate count_tokens call is needed)
            usage_metadata = completion.usage_metadata
            token_counts = {
                "input_tokens": usage_metadata.prompt_token_count,
                "output_tokens": usage_metadata.candidates_token_count
            }

            # Get the result
            response_content = completion.text

            # Log the response content and its type
            logging.info(f"Gemini Flash response type: {type(response_content)}")
            logging.info(f"Gemini Flash response content: {response_content}")

            # Try to parse the response as JSON
            try:
                parsed_data = json.loads(response_content)
                if isinstance(parsed_data, dict) and 'page_urls' in parsed_data:
                    pagination_data = PaginationData(**parsed_data)
                else:
                    pagination_data = PaginationData(page_urls=[])
            except json.JSONDecodeError:
                logging.error("Failed to parse Gemini Flash response as JSON")
                pagination_data = PaginationData(page_urls=[])

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return pagination_data, token_counts, pagination_price
elif selected_model == "Llama3.1 8B":
# Use Llama model via OpenAI API pointing to local server
openai.api_key = "lm-studio"
openai.api_base = "http://localhost:1234/v1"
response = openai.ChatCompletion.create(
model=LLAMA_MODEL_FULLNAME,
messages=[
{"role": "system", "content": prompt_pagination},
{"role": "user", "content": markdown_content},
],
temperature=0.7,
)
response_content = response['choices'][0]['message']['content'].strip()
# Try to parse the JSON
try:
pagination_data = json.loads(response_content)
except json.JSONDecodeError:
pagination_data = {"next_buttons": [], "page_urls": []}
# Token counts
token_counts = {
"input_tokens": response['usage']['prompt_tokens'],
"output_tokens": response['usage']['completion_tokens']
}
# Calculate the price
pagination_price = calculate_pagination_price(token_counts, selected_model)
return pagination_data, token_counts, pagination_price
elif selected_model == "Groq Llama3.1 70b":
# Use Groq client
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
response = client.chat.completions.create(
model=GROQ_LLAMA_MODEL_FULLNAME,
messages=[
{"role": "system", "content": prompt_pagination},
{"role": "user", "content": markdown_content},
],
)
response_content = response.choices[0].message.content.strip()
# Try to parse the JSON
try:
pagination_data = json.loads(response_content)
except json.JSONDecodeError:
pagination_data = {"page_urls": []}
# Token counts
token_counts = {
"input_tokens": response.usage.prompt_tokens,
"output_tokens": response.usage.completion_tokens
}
# Calculate the price
pagination_price = calculate_pagination_price(token_counts, selected_model)
# Ensure the pagination_data is a dictionary
if isinstance(pagination_data, PaginationData):
pagination_data = pagination_data.dict()
elif not isinstance(pagination_data, dict):
pagination_data = {"page_urls": []}
return pagination_data, token_counts, pagination_price
        else:
            raise ValueError(f"Unsupported model: {selected_model}")

    except Exception as e:
        logging.error(f"An error occurred in detect_pagination_elements: {e}")
        # Return default values if an error occurs
        return PaginationData(page_urls=[]), {"input_tokens": 0, "output_tokens": 0}, 0.0
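
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the module's API):
    # assumes OPENAI_API_KEY is set in .env; the URL and markdown below are
    # made-up stand-ins for real scraped content.
    sample_markdown = (
        "[1](https://example.com/products?page=1) "
        "[2](https://example.com/products?page=2) "
        "[Next](https://example.com/products?page=2)"
    )
    data, tokens, price = detect_pagination_elements(
        url="https://example.com/products",
        indications="",
        selected_model="gpt-4o-mini",
        markdown_content=sample_markdown,
    )
    print(data, tokens, f"${price:.6f}")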