File size: 6,312 Bytes
55638b2
bc989cb
 
55638b2
 
 
bc989cb
55638b2
 
 
 
 
 
 
 
bc989cb
55638b2
 
 
bc989cb
55638b2
bc989cb
 
55638b2
 
 
bc989cb
55638b2
 
 
bc989cb
55638b2
bc989cb
 
55638b2
 
 
 
 
 
 
 
 
 
 
 
bc989cb
55638b2
 
bc989cb
55638b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc989cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55638b2
bc989cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55638b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import re
import time
from concurrent.futures import ThreadPoolExecutor

import openai
import pandas as pd
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm


class AutoTab:
    """Fill the empty "[Output] ..." columns of an Excel sheet by few-shot
    prompting an OpenAI-compatible chat API.

    Rows that already have output values serve as in-context examples; the
    remaining rows are predicted in batches and periodically saved back to
    ``out_file_path``.
    """

    def __init__(
        self,
        in_file_path: str,
        out_file_path: str,
        instruction: str,
        max_examples: int,
        model_name: str,
        generation_config: dict,
        request_interval: float,
        save_every: int,
        api_key: str,
        base_url: str,
    ):
        self.in_file_path = in_file_path  # source .xlsx with [Input]/[Output] columns
        self.out_file_path = out_file_path  # destination .xlsx for predictions
        self.instruction = instruction  # task description prepended to every prompt
        self.max_examples = max_examples  # cap on in-context examples
        self.model_name = model_name
        self.generation_config = generation_config  # kwargs forwarded to the API call
        self.request_interval = request_interval  # seconds slept before each request
        self.save_every = save_every  # batch size between checkpoint saves
        self.api_key = api_key
        self.base_url = base_url

    # ─── IO ───────────────────────────────────────────────────────────────

    def load_excel(self) -> tuple[pd.DataFrame, list, list]:
        """Load the Excel file and identify input and output fields.

        Returns:
            The full DataFrame, the columns prefixed "[Input] ", and the
            columns prefixed "[Output] ".
        """
        df = pd.read_excel(self.in_file_path)
        input_fields = [col for col in df.columns if col.startswith("[Input] ")]
        output_fields = [col for col in df.columns if col.startswith("[Output] ")]
        return df, input_fields, output_fields

    # ─── LLM ──────────────────────────────────────────────────────────────

    @retry(wait=wait_random_exponential(min=20, max=60), stop=stop_after_attempt(6))
    def openai_request(self, query: str) -> str:
        """Make a request to an OpenAI-format API and return the reply text.

        Retries with exponential backoff (up to 6 attempts) on any failure;
        sleeps ``request_interval`` seconds first to respect rate limits.
        """
        time.sleep(self.request_interval)
        client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": query}],
            **self.generation_config,
        )
        str_response = response.choices[0].message.content.strip()
        return str_response

    # ─── In-Context Learning ──────────────────────────────────────────────

    def derive_incontext(
        self, data: pd.DataFrame, input_columns: list[str], output_columns: list[str]
    ) -> str:
        """Derive the in-context prompt with angle brackets.

        Uses at most ``max_examples`` rows whose output columns are all
        non-null, serializing each row as ``<Field>value</Field>`` lines.
        """
        n = min(self.max_examples, len(data.dropna(subset=output_columns)))
        in_context = ""
        for i in range(n):
            in_context += "".join(
                f"<{col.replace('[Input] ', '')}>{data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
                for col in input_columns
            )
            in_context += "".join(
                f"<{col.replace('[Output] ', '')}>{data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
                for col in output_columns
            )
            in_context += "\n"
        return in_context

    def predict_output(
        self, in_context: str, input_data: pd.Series, input_fields: list[str]
    ) -> str:
        """Predict the output values for the given input row using the API.

        Args:
            in_context: Few-shot examples produced by ``derive_incontext``.
            input_data: A single row of the DataFrame.
            input_fields: The "[Input] ..." column names to serialize.
        """
        query = (
            self.instruction
            + "\n\n"
            + in_context
            + "".join(
                f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n"
                for col in input_fields
            )
        )
        # Kept for debugging/inspection of the last prompt sent.
        self.query_example = query
        output = self.openai_request(query)
        return output

    def extract_fields(
        self, response: str, output_columns: list[str]
    ) -> dict[str, str]:
        """Extract fields from the response text based on output columns.

        Returns a dict keyed by the original "[Output] ..." column names;
        missing fields map to "".
        """
        extracted = {}
        for col in output_columns:
            field = col.replace("[Output] ", "")
            # re.escape: field names may contain regex metacharacters.
            # re.DOTALL: model answers can legitimately span multiple lines.
            match = re.search(
                f"<{re.escape(field)}>(.*?)</{re.escape(field)}>",
                response,
                re.DOTALL,
            )
            extracted[col] = match.group(1) if match else ""
        return extracted

    # ─── Engine ───────────────────────────────────────────────────────────

    def _predict_and_extract(self, i: int) -> dict[str, str]:
        """Helper function to predict and extract fields for a single row."""
        prediction = self.predict_output(
            self.in_context, self.data.iloc[i], self.input_fields
        )
        extracted_fields = self.extract_fields(prediction, self.output_fields)
        return extracted_fields

    def batch_prediction(self, start_index: int, end_index: int):
        """Process a batch of predictions concurrently and write results
        into ``self.data`` in place."""
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._predict_and_extract, range(start_index, end_index))
            )
        for i, extracted_fields in zip(range(start_index, end_index), results):
            for field_name in self.output_fields:
                self.data.at[i, field_name] = extracted_fields.get(field_name, "")

    def run(self):
        """Predict all unfilled rows, checkpointing every ``save_every`` rows.

        Assumes example rows (with outputs already present) occupy the first
        ``num_examples`` rows of the sheet — TODO confirm with input format.
        """
        self.data, self.input_fields, self.output_fields = self.load_excel()
        self.in_context = self.derive_incontext(
            self.data, self.input_fields, self.output_fields
        )

        self.num_data = len(self.data)
        self.num_examples = len(self.data.dropna(subset=self.output_fields))

        tqdm_bar = tqdm(total=self.num_data - self.num_examples, leave=False)
        for start in range(self.num_examples, self.num_data, self.save_every):
            end = min(start + self.save_every, self.num_data)
            try:
                self.batch_prediction(start, end)
            except Exception as e:
                # Best-effort: a failed batch is logged and skipped so earlier
                # progress is still checkpointed below.
                print(f"Batch {start}-{end} failed: {e}")
            # Advance the bar only after the batch actually completed.
            tqdm_bar.update(end - start)
            self.data.to_excel(self.out_file_path, index=False)
        tqdm_bar.close()
        self.data.to_excel(self.out_file_path, index=False)
        print(f"Results saved to {self.out_file_path}")