# pagination_detector.py

import os
import json
import logging
from typing import List, Dict, Tuple, Union

from pydantic import BaseModel, Field
import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
import google.generativeai as genai
from groq import Groq

from assets import PROMPT_PAGINATION, PRICING, LLAMA_MODEL_FULLNAME, GROQ_LLAMA_MODEL_FULLNAME

load_dotenv()


class PaginationData(BaseModel):
    page_urls: List[str] = Field(
        default_factory=list,
        description="List of pagination URLs, including 'Next' button URL if present"
    )


def calculate_pagination_price(token_counts: Dict[str, int], model: str) -> float:
    """
    Calculate the price for pagination based on token counts and the selected model.

    Args:
        token_counts (Dict[str, int]): A dictionary containing 'input_tokens' and 'output_tokens'.
        model (str): The name of the selected model.

    Returns:
        float: The total price for the pagination operation.
    """
    input_tokens = token_counts['input_tokens']
    output_tokens = token_counts['output_tokens']

    input_price = input_tokens * PRICING[model]['input']
    output_price = output_tokens * PRICING[model]['output']

    return input_price + output_price


def detect_pagination_elements(
    url: str, indications: str, selected_model: str, markdown_content: str
) -> Tuple[Union[PaginationData, Dict, str], Dict, float]:
    """
    Uses AI models to analyze markdown content and extract pagination elements.

    Args:
        url (str): The URL of the page to extract pagination from.
        indications (str): Optional user indications to guide the extraction.
        selected_model (str): The name of the model to use.
        markdown_content (str): The markdown content to analyze.

    Returns:
        Tuple[Union[PaginationData, Dict, str], Dict, float]:
            Parsed pagination data, token counts, and pagination price.
    """
    try:
        prompt_pagination = (
            PROMPT_PAGINATION
            + "\nThe url of the page to extract pagination from: " + url
            + "\nIf the urls that you find are not complete, combine them intelligently "
              "in a way that fits the pattern. **ALWAYS GIVE A FULL URL**"
        )
        if indications != "":
            prompt_pagination += (
                "\n\nThese are the user's indications; pay special attention to them: "
                + indications
                + "\n\nBelow are the markdowns of the website:\n\n"
            )
        else:
            prompt_pagination += (
                "\nThere are no user indications in this case, just apply the logic described."
                "\n\nBelow are the markdowns of the website:\n\n"
            )

        if selected_model in ["gpt-4o-mini", "gpt-4o-2024-08-06"]:
            # Use OpenAI API with structured output parsed into PaginationData
            client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
            completion = client.beta.chat.completions.parse(
                model=selected_model,
                messages=[
                    {"role": "system", "content": prompt_pagination},
                    {"role": "user", "content": markdown_content},
                ],
                response_format=PaginationData
            )

            # Extract the parsed response
            parsed_response = completion.choices[0].message.parsed

            # Calculate tokens using tiktoken
            encoder = tiktoken.encoding_for_model(selected_model)
            input_token_count = len(encoder.encode(markdown_content))
            output_token_count = len(encoder.encode(json.dumps(parsed_response.dict())))
            token_counts = {
                "input_tokens": input_token_count,
                "output_tokens": output_token_count
            }

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return parsed_response, token_counts, pagination_price

        elif selected_model == "gemini-1.5-flash":
            # Use Google Gemini API with a JSON response schema
            genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
            model = genai.GenerativeModel(
                'gemini-1.5-flash',
                generation_config={
                    "response_mime_type": "application/json",
                    "response_schema": PaginationData
                }
            )
            prompt = f"{prompt_pagination}\n{markdown_content}"

            completion = model.generate_content(prompt)

            # Extract token counts from usage_metadata
            usage_metadata = completion.usage_metadata
            token_counts = {
                "input_tokens": usage_metadata.prompt_token_count,
                "output_tokens": usage_metadata.candidates_token_count
            }

            # Get the result
            response_content = completion.text

            # Log the response content and its type
            logging.info(f"Gemini Flash response type: {type(response_content)}")
            logging.info(f"Gemini Flash response content: {response_content}")

            # Try to parse the response as JSON
            try:
                parsed_data = json.loads(response_content)
                if isinstance(parsed_data, dict) and 'page_urls' in parsed_data:
                    pagination_data = PaginationData(**parsed_data)
                else:
                    pagination_data = PaginationData(page_urls=[])
            except json.JSONDecodeError:
                logging.error("Failed to parse Gemini Flash response as JSON")
                pagination_data = PaginationData(page_urls=[])

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return pagination_data, token_counts, pagination_price

        elif selected_model == "Llama3.1 8B":
            # Use a Llama model via an OpenAI-compatible local server (e.g. LM Studio)
            client = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")
            response = client.chat.completions.create(
                model=LLAMA_MODEL_FULLNAME,
                messages=[
                    {"role": "system", "content": prompt_pagination},
                    {"role": "user", "content": markdown_content},
                ],
                temperature=0.7,
            )
            response_content = response.choices[0].message.content.strip()

            # Try to parse the JSON
            try:
                pagination_data = json.loads(response_content)
            except json.JSONDecodeError:
                pagination_data = {"page_urls": []}

            # Token counts
            token_counts = {
                "input_tokens": response.usage.prompt_tokens,
                "output_tokens": response.usage.completion_tokens
            }

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return pagination_data, token_counts, pagination_price

        elif selected_model == "Groq Llama3.1 70b":
            # Use Groq client
            client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
            response = client.chat.completions.create(
                model=GROQ_LLAMA_MODEL_FULLNAME,
                messages=[
                    {"role": "system", "content": prompt_pagination},
                    {"role": "user", "content": markdown_content},
                ],
            )
            response_content = response.choices[0].message.content.strip()

            # Try to parse the JSON
            try:
                pagination_data = json.loads(response_content)
            except json.JSONDecodeError:
                pagination_data = {"page_urls": []}

            # Token counts
            token_counts = {
                "input_tokens": response.usage.prompt_tokens,
                "output_tokens": response.usage.completion_tokens
            }

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            # Ensure the pagination_data is a dictionary
            if isinstance(pagination_data, PaginationData):
                pagination_data = pagination_data.dict()
            elif not isinstance(pagination_data, dict):
                pagination_data = {"page_urls": []}

            return pagination_data, token_counts, pagination_price

        else:
            raise ValueError(f"Unsupported model: {selected_model}")

    except Exception as e:
        logging.error(f"An error occurred in detect_pagination_elements: {e}")
        # Return default values if an error occurs
        return PaginationData(page_urls=[]), {"input_tokens": 0, "output_tokens": 0}, 0.0