mobenta committed
Commit 0dbb012 · verified · 1 Parent(s): a54cb5b

Delete pagination_detector.py.txt

Files changed (1)
  1. pagination_detector.py.txt +0 -206
pagination_detector.py.txt DELETED
@@ -1,206 +0,0 @@
- # pagination_detector.py
-
- import os
- import json
- import logging
- from typing import List, Dict, Tuple, Union
- from pydantic import BaseModel, Field, ValidationError
-
- import tiktoken
- from dotenv import load_dotenv
-
- from openai import OpenAI
- import google.generativeai as genai
- from groq import Groq
-
- from assets import PROMPT_PAGINATION, PRICING, LLAMA_MODEL_FULLNAME, GROQ_LLAMA_MODEL_FULLNAME
-
- load_dotenv()
-
- class PaginationData(BaseModel):
-     page_urls: List[str] = Field(default_factory=list, description="List of pagination URLs, including 'Next' button URL if present")
-
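- # Models are asked to return JSON of this shape (illustrative example):
- #   {"page_urls": ["https://example.com/page/2", "https://example.com/page/3"]}
-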
- def calculate_pagination_price(token_counts: Dict[str, int], model: str) -> float:
-     """
-     Calculate the price for pagination based on token counts and the selected model.
-
-     Args:
-         token_counts (Dict[str, int]): A dictionary containing 'input_tokens' and 'output_tokens'.
-         model (str): The name of the selected model.
-
-     Returns:
-         float: The total price for the pagination operation.
-     """
-     input_tokens = token_counts['input_tokens']
-     output_tokens = token_counts['output_tokens']
-
-     input_price = input_tokens * PRICING[model]['input']
-     output_price = output_tokens * PRICING[model]['output']
-
-     return input_price + output_price
-
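- # Note: the PRICING rates from assets are assumed to be per-token prices under
- # the 'input' and 'output' keys of each model's entry.
-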
- def detect_pagination_elements(url: str, indications: str, selected_model: str, markdown_content: str) -> Tuple[Union[PaginationData, Dict, str], Dict, float]:
-     """
-     Uses AI models to analyze markdown content and extract pagination elements.
-
-     Args:
-         url (str): The URL of the page to extract pagination from.
-         indications (str): Optional user indications to guide the extraction.
-         selected_model (str): The name of the model to use.
-         markdown_content (str): The markdown content to analyze.
-
-     Returns:
-         Tuple[PaginationData, Dict, float]: Parsed pagination data, token counts, and pagination price.
-     """
-     try:
-         prompt_pagination = (
-             PROMPT_PAGINATION
-             + "\n The url of the page to extract pagination from: " + url
-             + "\n If the urls that you find are not complete, combine them intelligently in a way that fits the pattern. **ALWAYS GIVE A FULL URL**"
-         )
-         if indications != "":
-             prompt_pagination += "\n\n These are the user's indications; pay special attention to them: " + indications + "\n\n Below are the markdowns of the website: \n\n"
-         else:
-             prompt_pagination += "\n There are no user indications in this case; just apply the logic described. \n\n Below are the markdowns of the website: \n\n"
-
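-         # Dispatch on the selected model; every branch returns
-         # (pagination_data, token_counts, pagination_price).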
-         if selected_model in ["gpt-4o-mini", "gpt-4o-2024-08-06"]:
-             # Use OpenAI API
-             client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
-             completion = client.beta.chat.completions.parse(
-                 model=selected_model,
-                 messages=[
-                     {"role": "system", "content": prompt_pagination},
-                     {"role": "user", "content": markdown_content},
-                 ],
-                 response_format=PaginationData
-             )
-
-             # Extract the parsed response
-             parsed_response = completion.choices[0].message.parsed
-
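-             # Note: the tiktoken counts below cover only markdown_content and the
-             # serialized response, not the system prompt, so input_tokens is an
-             # undercount of actual API usage.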
-             # Calculate tokens using tiktoken
-             encoder = tiktoken.encoding_for_model(selected_model)
-             input_token_count = len(encoder.encode(markdown_content))
-             output_token_count = len(encoder.encode(json.dumps(parsed_response.dict())))
-             token_counts = {
-                 "input_tokens": input_token_count,
-                 "output_tokens": output_token_count
-             }
-
-             # Calculate the price
-             pagination_price = calculate_pagination_price(token_counts, selected_model)
-
-             return parsed_response, token_counts, pagination_price
-
-         elif selected_model == "gemini-1.5-flash":
-             # Use Google Gemini API
-             genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
-             model = genai.GenerativeModel(
-                 'gemini-1.5-flash',
-                 generation_config={
-                     "response_mime_type": "application/json",
-                     "response_schema": PaginationData
-                 }
-             )
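-             # response_schema asks Gemini to emit JSON matching PaginationData;
-             # the reply is still parsed defensively below in case the model deviates.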
-             prompt = f"{prompt_pagination}\n{markdown_content}"
-             completion = model.generate_content(prompt)
-             # Extract token counts from usage_metadata
-             usage_metadata = completion.usage_metadata
-             token_counts = {
-                 "input_tokens": usage_metadata.prompt_token_count,
-                 "output_tokens": usage_metadata.candidates_token_count
-             }
-             # Get the result
-             response_content = completion.text
-
-             # Log the response content and its type
-             logging.info(f"Gemini Flash response type: {type(response_content)}")
-             logging.info(f"Gemini Flash response content: {response_content}")
-
-             # Try to parse the response as JSON
-             try:
-                 parsed_data = json.loads(response_content)
-                 if isinstance(parsed_data, dict) and 'page_urls' in parsed_data:
-                     pagination_data = PaginationData(**parsed_data)
-                 else:
-                     pagination_data = PaginationData(page_urls=[])
-             except json.JSONDecodeError:
-                 logging.error("Failed to parse Gemini Flash response as JSON")
-                 pagination_data = PaginationData(page_urls=[])
-
-             # Calculate the price
-             pagination_price = calculate_pagination_price(token_counts, selected_model)
-
-             return pagination_data, token_counts, pagination_price
-
-         elif selected_model == "Llama3.1 8B":
-             # Use a local Llama model through an OpenAI-compatible server (e.g. LM Studio)
-             client = OpenAI(api_key="lm-studio", base_url="http://localhost:1234/v1")
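-             # Local OpenAI-compatible servers generally ignore the API key, so
-             # "lm-studio" here is just a placeholder.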
-             response = client.chat.completions.create(
-                 model=LLAMA_MODEL_FULLNAME,
-                 messages=[
-                     {"role": "system", "content": prompt_pagination},
-                     {"role": "user", "content": markdown_content},
-                 ],
-                 temperature=0.7,
-             )
-             response_content = response.choices[0].message.content.strip()
-             # Try to parse the JSON
-             try:
-                 pagination_data = json.loads(response_content)
-             except json.JSONDecodeError:
-                 pagination_data = {"page_urls": []}
-             # Token counts
-             token_counts = {
-                 "input_tokens": response.usage.prompt_tokens,
-                 "output_tokens": response.usage.completion_tokens
-             }
-             # Calculate the price
-             pagination_price = calculate_pagination_price(token_counts, selected_model)
-
-             return pagination_data, token_counts, pagination_price
-
-         elif selected_model == "Groq Llama3.1 70b":
-             # Use Groq client
-             client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
-             response = client.chat.completions.create(
-                 model=GROQ_LLAMA_MODEL_FULLNAME,
-                 messages=[
-                     {"role": "system", "content": prompt_pagination},
-                     {"role": "user", "content": markdown_content},
-                 ],
-             )
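-             # No structured response format is requested here, so the reply may
-             # not be valid JSON; the try/except below guards against that.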
-             response_content = response.choices[0].message.content.strip()
-             # Try to parse the JSON
-             try:
-                 pagination_data = json.loads(response_content)
-             except json.JSONDecodeError:
-                 pagination_data = {"page_urls": []}
-             # Token counts
-             token_counts = {
-                 "input_tokens": response.usage.prompt_tokens,
-                 "output_tokens": response.usage.completion_tokens
-             }
-             # Calculate the price
-             pagination_price = calculate_pagination_price(token_counts, selected_model)
-
-             # Ensure the pagination_data is a dictionary
-             if isinstance(pagination_data, PaginationData):
-                 pagination_data = pagination_data.dict()
-             elif not isinstance(pagination_data, dict):
-                 pagination_data = {"page_urls": []}
-
-             return pagination_data, token_counts, pagination_price
-
-         else:
-             raise ValueError(f"Unsupported model: {selected_model}")
-
-     except Exception as e:
-         logging.error(f"An error occurred in detect_pagination_elements: {e}")
-         # Return default values if an error occurs
-         return PaginationData(page_urls=[]), {"input_tokens": 0, "output_tokens": 0}, 0.0