File size: 6,664 Bytes
b781aec
0269e89
6e11e8c
b781aec
 
 
6e11e8c
0b48c77
 
 
 
 
 
 
 
accf5a5
0b48c77
 
 
accf5a5
0b48c77
652b4e6
accf5a5
0b48c77
 
 
accf5a5
0b48c77
 
 
accf5a5
0b48c77
652b4e6
accf5a5
f5cc66a
652b4e6
f5cc66a
 
 
 
 
 
0b48c77
 
 
 
 
 
 
 
 
 
 
 
6e11e8c
0b48c77
 
652b4e6
 
6e11e8c
652b4e6
 
 
 
 
 
0b48c77
 
 
 
b781aec
0b48c77
 
 
 
 
f5cc66a
0b48c77
f5cc66a
0b48c77
f5cc66a
0b48c77
f5cc66a
 
0b48c77
 
f5cc66a
 
0b48c77
 
 
 
f5cc66a
0b48c77
 
 
 
f5cc66a
0b48c77
 
f5cc66a
0b48c77
 
 
 
 
 
f5cc66a
0b48c77
 
f5cc66a
0b48c77
 
 
f5cc66a
 
0b48c77
 
 
 
f5cc66a
6e11e8c
f5cc66a
 
 
 
 
 
 
 
6e11e8c
0269e89
6e11e8c
 
 
 
0269e89
6e11e8c
0269e89
 
 
0b48c77
f5cc66a
 
94742c1
6e11e8c
 
 
 
 
0269e89
 
0b48c77
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import openai
import pandas as pd
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm


class AutoTab:
    """Auto-complete missing ``[Output] ...`` columns of an Excel sheet.

    Rows that already have every output column filled serve as in-context
    (few-shot) examples; the remaining rows are completed by querying an
    OpenAI-compatible chat API, with periodic checkpointing to disk.
    """

    def __init__(
        self,
        in_file_path: str,
        out_file_path: str,
        instruction: str,
        max_examples: int,
        model_name: str,
        generation_config: dict,
        request_interval: float,
        save_every: int,
        api_keys: list[str],
        base_url: str,
    ):
        self.in_file_path = in_file_path
        self.out_file_path = out_file_path
        self.instruction = instruction
        self.max_examples = max_examples
        self.model_name = model_name
        self.generation_config = generation_config
        self.request_interval = request_interval
        self.save_every = save_every
        self.api_keys = api_keys
        self.base_url = base_url

        # Counters are mutated from ThreadPoolExecutor workers; guard them.
        self._lock = threading.Lock()
        self.request_count = 0  # total API calls made (drives key rotation)
        self.failed_count = 0   # responses missing at least one output field

        self.data, self.input_fields, self.output_fields = self.load_excel()
        self.in_context = self.derive_incontext()
        self.num_data = len(self.data)
        self.num_example = len(self.data.dropna(subset=self.output_fields))
        self.num_missing = self.num_data - self.num_example

    # ─── IO ───────────────────────────────────────────────────────────────

    def load_excel(self) -> tuple[pd.DataFrame, list, list]:
        """Load the Excel file and identify input and output fields.

        Returns:
            (dataframe, input column names, output column names), where the
            columns are recognized by their ``[Input] `` / ``[Output] ``
            name prefixes.
        """
        df = pd.read_excel(self.in_file_path)
        input_fields = [col for col in df.columns if col.startswith("[Input] ")]
        output_fields = [col for col in df.columns if col.startswith("[Output] ")]
        return df, input_fields, output_fields

    # ─── LLM ──────────────────────────────────────────────────────────────

    @retry(wait=wait_random_exponential(min=20, max=60), stop=stop_after_attempt(6))
    def openai_request(self, query: str) -> str:
        """Make a request to an OpenAI-format API and return the reply text.

        Retries with randomized exponential backoff (up to 6 attempts) on any
        exception, per the ``tenacity`` decorator.
        """
        # Throttle: wait the configured interval before each request.
        time.sleep(self.request_interval)

        # Rotate across API keys round-robin; the counter is shared between
        # worker threads, so read-and-increment must be atomic.
        with self._lock:
            api_key = self.api_keys[self.request_count % len(self.api_keys)]
            self.request_count += 1

        client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": query}],
            **self.generation_config,
        )
        return response.choices[0].message.content.strip()

    # ─── In-Context Learning ──────────────────────────────────────────────

    def derive_incontext(self) -> str:
        """Derive the in-context prompt with angle brackets.

        Only rows with every output column present are used as examples,
        capped at ``self.max_examples``.
        """
        examples = self.data.dropna(subset=self.output_fields)[: self.max_examples]
        parts: list[str] = []
        for i in range(len(examples)):
            # BUGFIX: index into `examples`, not `self.data` — after dropna,
            # positional row i of the full frame may be an incomplete row,
            # which previously leaked "nan" outputs into the prompt.
            for col in self.input_fields:
                tag = col.replace("[Input] ", "")
                parts.append(f"<{tag}>{examples[col].iloc[i]}</{tag}>\n")
            for col in self.output_fields:
                tag = col.replace("[Output] ", "")
                parts.append(f"<{tag}>{examples[col].iloc[i]}</{tag}>\n")
            parts.append("\n")
        return "".join(parts)

    def predict_output(self, input_data: pd.DataFrame):
        """Predict the output values for the given input data using the API."""
        query = (
            self.instruction
            + "\n\n"
            + self.in_context
            + "".join(
                f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n"
                for col in self.input_fields
            )
        )
        # Keep the last query around for debugging/inspection.
        self.query_example = query
        return self.openai_request(query)

    def extract_fields(self, response: str) -> dict[str, str]:
        """Extract fields from the response text based on output columns.

        A field that cannot be found yields "" and the row is counted as a
        (partial) failure.
        """
        extracted = {}
        for col in self.output_fields:
            field = col.replace("[Output] ", "")
            # re.escape: field names may contain regex metacharacters.
            # re.DOTALL: model replies may span multiple lines inside a tag.
            pattern = f"<{re.escape(field)}>(.*?)</{re.escape(field)}>"
            match = re.search(pattern, response, flags=re.DOTALL)
            extracted[col] = match.group(1) if match else ""
        if any(extracted[col] == "" for col in self.output_fields):
            with self._lock:  # called from worker threads
                self.failed_count += 1
        return extracted

    # ─── Engine ───────────────────────────────────────────────────────────

    def _predict_and_extract(self, row: int) -> dict[str, str]:
        """Helper function to predict and extract fields for a single row."""
        # Only call the API when at least one output field is still empty;
        # otherwise keep the existing values untouched.
        if any(pd.isnull(self.data.at[row, col]) for col in self.output_fields):
            prediction = self.predict_output(self.data.iloc[row])
            return self.extract_fields(prediction)
        return {col: self.data.at[row, col] for col in self.output_fields}

    def batch_prediction(self, start_index: int, end_index: int):
        """Process a batch of predictions concurrently and write them back."""
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._predict_and_extract, range(start_index, end_index))
            )
        # Write back on the main thread only, so the DataFrame is never
        # mutated concurrently.
        for i, extracted_fields in zip(range(start_index, end_index), results):
            for field_name in self.output_fields:
                self.data.at[i, field_name] = extracted_fields.get(field_name, "")

    def run(self):
        """Process all rows in checkpointed batches and save the result."""
        tqdm_bar = tqdm(total=self.num_data, leave=False)
        for start in range(0, self.num_data, self.save_every):
            end = min(start + self.save_every, self.num_data)
            try:
                self.batch_prediction(start, end)
            except Exception as e:
                # Best-effort: a failed batch is logged and skipped so that
                # already-computed batches are still checkpointed below.
                print(e)
            # Advance the bar only after the batch actually finished.
            tqdm_bar.update(end - start)
            # Checkpoint after every batch so progress survives a crash.
            self.data.to_excel(self.out_file_path, index=False)
        self.data.to_excel(self.out_file_path, index=False)
        print(f"Results saved to {self.out_file_path}")