import re
import time
from concurrent.futures import ThreadPoolExecutor

import openai
import pandas as pd
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm


class AutoTab:
def __init__(
self,
in_file_path: str,
out_file_path: str,
instruction: str,
max_examples: int,
model_name: str,
generation_config: dict,
request_interval: float,
save_every: int,
api_keys: list[str],
base_url: str,
):
self.in_file_path = in_file_path
self.out_file_path = out_file_path
self.instruction = instruction
self.max_examples = max_examples
self.model_name = model_name
self.generation_config = generation_config
self.request_interval = request_interval
self.save_every = save_every
self.api_keys = api_keys
self.base_url = base_url
self.request_count = 0
self.failed_count = 0
self.data, self.input_fields, self.output_fields = self.load_excel()
self.in_context = self.derive_incontext()
self.num_data = len(self.data)
self.num_example = len(self.data.dropna(subset=self.output_fields))
self.num_missing = self.num_data - self.num_example

    # ─── IO ──────────────────────────────────────────────────────────────
def load_excel(self) -> tuple[pd.DataFrame, list, list]:
"""Load the Excel file and identify input and output fields."""
df = pd.read_excel(self.in_file_path)
input_fields = [col for col in df.columns if col.startswith("[Input] ")]
output_fields = [col for col in df.columns if col.startswith("[Output] ")]
return df, input_fields, output_fields

    # ─── LLM ─────────────────────────────────────────────────────────────
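    # tenacity retries a failed request up to 6 times, waiting a randomized,
    # exponentially growing interval of 20-60 seconds before each retry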
@retry(wait=wait_random_exponential(min=20, max=60), stop=stop_after_attempt(6))
def openai_request(self, query: str) -> str:
"""Make a request to an OpenAI-format API."""
        # Throttle: wait the configured interval before each request
        time.sleep(self.request_interval)
        # Rotate through the available API keys, counting requests as we go
        api_key = self.api_keys[self.request_count % len(self.api_keys)]
        self.request_count += 1
        client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model_name,
messages=[{"role": "user", "content": query}],
**self.generation_config,
)
str_response = response.choices[0].message.content.strip()
return str_response

    # ─── In-Context Learning ─────────────────────────────────────────────
def derive_incontext(self) -> str:
"""Derive the in-context prompt with angle brackets."""
examples = self.data.dropna(subset=self.output_fields)[: self.max_examples]
in_context = ""
        for i in range(len(examples)):
            # Read from `examples` (not `self.data`) so positional indexing
            # stays aligned with the rows that actually have outputs
            in_context += "".join(
                f"<{col.replace('[Input] ', '')}>{examples[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
                for col in self.input_fields
            )
            in_context += "".join(
                f"<{col.replace('[Output] ', '')}>{examples[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
                for col in self.output_fields
            )
in_context += "\n"
return in_context

    def predict_output(self, input_data: pd.Series) -> str:
        """Predict the output values for the given input row using the API."""
query = (
self.instruction
+ "\n\n"
+ self.in_context
+ "".join(
f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n"
for col in self.input_fields
)
)
        # Keep the most recent query around for debugging/inspection
        self.query_example = query
output = self.openai_request(query)
return output

    def extract_fields(self, response: str) -> dict[str, str]:
        """Extract fields from the response text based on output columns."""
extracted = {}
for col in self.output_fields:
field = col.replace("[Output] ", "")
            # re.DOTALL lets values span multiple lines; escape the field name
            # in case it contains regex metacharacters
            match = re.search(
                f"<{re.escape(field)}>(.*?)</{re.escape(field)}>", response, re.DOTALL
            )
extracted[col] = match.group(1) if match else ""
if any(extracted[col] == "" for col in self.output_fields):
self.failed_count += 1
return extracted

    # ─── Engine ──────────────────────────────────────────────────────────
def _predict_and_extract(self, row: int) -> dict[str, str]:
"""Helper function to predict and extract fields for a single row."""
# If any output field is empty, predict the output
if any(pd.isnull(self.data.at[row, col]) for col in self.output_fields):
prediction = self.predict_output(self.data.iloc[row])
extracted_fields = self.extract_fields(prediction)
return extracted_fields
else:
return {col: self.data.at[row, col] for col in self.output_fields}

    def batch_prediction(self, start_index: int, end_index: int):
        """Process a batch of predictions concurrently with a thread pool."""
with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._predict_and_extract, range(start_index, end_index))
)
for i, extracted_fields in zip(range(start_index, end_index), results):
for field_name in self.output_fields:
self.data.at[i, field_name] = extracted_fields.get(field_name, "")

    def run(self):
        """Fill in missing outputs batch by batch, checkpointing as we go."""
        tqdm_bar = tqdm(total=self.num_data, leave=False)
        for start in range(0, self.num_data, self.save_every):
            tqdm_bar.update(min(self.save_every, self.num_data - start))
            end = min(start + self.save_every, self.num_data)
            try:
                self.batch_prediction(start, end)
            except Exception as e:
                print(e)
            # Checkpoint partial results after every batch
            self.data.to_excel(self.out_file_path, index=False)
        # Final save once all batches are done
        self.data.to_excel(self.out_file_path, index=False)
        print(f"Results saved to {self.out_file_path}")
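

# Example usage sketch: all values below (paths, instruction, model name, API
# key, base URL) are illustrative placeholders, not defaults of this module.
if __name__ == "__main__":
    autotab = AutoTab(
        in_file_path="data.xlsx",  # placeholder input spreadsheet
        out_file_path="data_filled.xlsx",  # placeholder output path
        instruction="Fill in the output fields based on the input fields.",
        max_examples=5,
        model_name="gpt-4o-mini",  # any OpenAI-compatible model name
        generation_config={"temperature": 0.0},
        request_interval=1.0,  # seconds to sleep before each request
        save_every=10,  # rows per checkpoint
        api_keys=["sk-..."],  # one or more keys, rotated per request
        base_url="https://api.openai.com/v1",
    )
    autotab.run()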