mobenta committed
Commit c04f6ef · verified · 1 Parent(s): 6f84c42

Rename pagination_detector.py.txt to pagination_detector.py

pagination_detector.py.txt → pagination_detector.py RENAMED
@@ -1 +1 @@
-pagination_detector.py
+
Full file after the rename:

# pagination_detector.py

import os
import json
import logging
from typing import List, Dict, Tuple, Union

import tiktoken
from dotenv import load_dotenv
from pydantic import BaseModel, Field

from openai import OpenAI
import google.generativeai as genai
from groq import Groq

from assets import PROMPT_PAGINATION, PRICING, LLAMA_MODEL_FULLNAME, GROQ_LLAMA_MODEL_FULLNAME

load_dotenv()


class PaginationData(BaseModel):
    page_urls: List[str] = Field(
        default_factory=list,
        description="List of pagination URLs, including the 'Next' button URL if present",
    )


def calculate_pagination_price(token_counts: Dict[str, int], model: str) -> float:
    """
    Calculate the price for pagination based on token counts and the selected model.

    Args:
        token_counts (Dict[str, int]): A dictionary containing 'input_tokens' and 'output_tokens'.
        model (str): The name of the selected model.

    Returns:
        float: The total price for the pagination operation.
    """
    input_tokens = token_counts['input_tokens']
    output_tokens = token_counts['output_tokens']

    input_price = input_tokens * PRICING[model]['input']
    output_price = output_tokens * PRICING[model]['output']

    return input_price + output_price


def detect_pagination_elements(
    url: str, indications: str, selected_model: str, markdown_content: str
) -> Tuple[Union[PaginationData, Dict, str], Dict, float]:
    """
    Uses AI models to analyze markdown content and extract pagination elements.

    Args:
        url (str): The URL of the page being analyzed.
        indications (str): Optional user indications to guide the extraction.
        selected_model (str): The name of the model to use.
        markdown_content (str): The markdown content to analyze.

    Returns:
        Tuple[PaginationData, Dict, float]: Parsed pagination data, token counts, and pagination price.
    """
    try:
        prompt_pagination = (
            PROMPT_PAGINATION
            + "\nThe URL of the page to extract pagination from: " + url
            + "\nIf the URLs you find are not complete, combine them intelligently so they fit the pattern. **ALWAYS GIVE A FULL URL**"
        )
        if indications != "":
            prompt_pagination += (
                "\n\nThese are the user's indications; pay special attention to them: "
                + indications
                + "\n\nBelow is the markdown of the website:\n\n"
            )
        else:
            prompt_pagination += (
                "\nThere are no user indications in this case; just apply the logic described."
                "\n\nBelow is the markdown of the website:\n\n"
            )

        if selected_model in ["gpt-4o-mini", "gpt-4o-2024-08-06"]:
            # Use the OpenAI API with structured outputs
            client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
            completion = client.beta.chat.completions.parse(
                model=selected_model,
                messages=[
                    {"role": "system", "content": prompt_pagination},
                    {"role": "user", "content": markdown_content},
                ],
                response_format=PaginationData,
            )

            # Extract the parsed response
            parsed_response = completion.choices[0].message.parsed

            # Count tokens with tiktoken
            encoder = tiktoken.encoding_for_model(selected_model)
            input_token_count = len(encoder.encode(markdown_content))
            output_token_count = len(encoder.encode(json.dumps(parsed_response.dict())))
            token_counts = {
                "input_tokens": input_token_count,
                "output_tokens": output_token_count,
            }

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return parsed_response, token_counts, pagination_price

        elif selected_model == "gemini-1.5-flash":
            # Use the Google Gemini API with a JSON response schema
            genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
            model = genai.GenerativeModel(
                'gemini-1.5-flash',
                generation_config={
                    "response_mime_type": "application/json",
                    "response_schema": PaginationData,
                },
            )
            prompt = f"{prompt_pagination}\n{markdown_content}"
            completion = model.generate_content(prompt)
            # Extract token counts from usage_metadata
            usage_metadata = completion.usage_metadata
            token_counts = {
                "input_tokens": usage_metadata.prompt_token_count,
                "output_tokens": usage_metadata.candidates_token_count,
            }
            # Get the result
            response_content = completion.text

            # Log the response content and its type
            logging.info(f"Gemini Flash response type: {type(response_content)}")
            logging.info(f"Gemini Flash response content: {response_content}")

            # Try to parse the response as JSON
            try:
                parsed_data = json.loads(response_content)
                if isinstance(parsed_data, dict) and 'page_urls' in parsed_data:
                    pagination_data = PaginationData(**parsed_data)
                else:
                    pagination_data = PaginationData(page_urls=[])
            except json.JSONDecodeError:
                logging.error("Failed to parse Gemini Flash response as JSON")
                pagination_data = PaginationData(page_urls=[])

            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return pagination_data, token_counts, pagination_price

        elif selected_model == "Llama3.1 8B":
            # Use a local Llama model through LM Studio's OpenAI-compatible server
            client = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")
            response = client.chat.completions.create(
                model=LLAMA_MODEL_FULLNAME,
                messages=[
                    {"role": "system", "content": prompt_pagination},
                    {"role": "user", "content": markdown_content},
                ],
                temperature=0.7,
            )
            response_content = response.choices[0].message.content.strip()
            # Try to parse the JSON
            try:
                pagination_data = json.loads(response_content)
            except json.JSONDecodeError:
                pagination_data = {"page_urls": []}
            # Token counts come straight from the API usage data
            token_counts = {
                "input_tokens": response.usage.prompt_tokens,
                "output_tokens": response.usage.completion_tokens,
            }
            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            return pagination_data, token_counts, pagination_price

        elif selected_model == "Groq Llama3.1 70b":
            # Use the Groq client
            client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
            response = client.chat.completions.create(
                model=GROQ_LLAMA_MODEL_FULLNAME,
                messages=[
                    {"role": "system", "content": prompt_pagination},
                    {"role": "user", "content": markdown_content},
                ],
            )
            response_content = response.choices[0].message.content.strip()
            # Try to parse the JSON
            try:
                pagination_data = json.loads(response_content)
            except json.JSONDecodeError:
                pagination_data = {"page_urls": []}
            # Token counts
            token_counts = {
                "input_tokens": response.usage.prompt_tokens,
                "output_tokens": response.usage.completion_tokens,
            }
            # Calculate the price
            pagination_price = calculate_pagination_price(token_counts, selected_model)

            # Ensure pagination_data is a dictionary
            if isinstance(pagination_data, PaginationData):
                pagination_data = pagination_data.dict()
            elif not isinstance(pagination_data, dict):
                pagination_data = {"page_urls": []}

            return pagination_data, token_counts, pagination_price

        else:
            raise ValueError(f"Unsupported model: {selected_model}")

    except Exception as e:
        logging.error(f"An error occurred in detect_pagination_elements: {e}")
        # Return default values if an error occurs
        return PaginationData(page_urls=[]), {"input_tokens": 0, "output_tokens": 0}, 0.0
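
For context, here is a minimal usage sketch of the module above. It is not part of the commit: it assumes a `.env` file providing OPENAI_API_KEY (plus GOOGLE_API_KEY / GROQ_API_KEY for the other branches), an `assets.py` supplying PROMPT_PAGINATION and a PRICING table keyed by model name with per-token 'input' and 'output' rates, and a placeholder markdown string standing in for real scraped content.

# usage_example.py -- hypothetical driver, not part of this repository
from pagination_detector import detect_pagination_elements

markdown = "... markdown of a listing page, e.g. produced by a scraper ..."

data, tokens, price = detect_pagination_elements(
    url="https://example.com/products?page=1",  # hypothetical target page
    indications="",                             # no extra user guidance
    selected_model="gpt-4o-mini",               # must be a key in PRICING
    markdown_content=markdown,
)

# The OpenAI/Gemini branches return a PaginationData instance; the
# Llama/Groq branches return a plain dict, so handle both shapes.
page_urls = data.page_urls if hasattr(data, "page_urls") else data.get("page_urls", [])
print(page_urls)
print(f"{tokens['input_tokens']} in / {tokens['output_tokens']} out -> ${price:.6f}")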