File size: 5,006 Bytes
55638b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import re

import openai
import pandas as pd
from tqdm import tqdm


class AutoTab:
    """Fill the empty ``[Output] ...`` cells of a spreadsheet with an LLM.

    Columns whose header starts with ``"[Input] "`` are treated as inputs and
    columns starting with ``"[Output] "`` as targets. Rows whose outputs are
    already filled in serve as few-shot examples; rows with a missing output
    are sent to an OpenAI-format chat endpoint and the tagged answer is parsed
    back into the table.
    """

    INPUT_PREFIX = "[Input] "
    OUTPUT_PREFIX = "[Output] "

    def __init__(
        self,
        in_file_path: str,
        out_file_path: str,
        max_examples: int,
        model_name: str,
        api_key: str,
        base_url: str,
        generation_config: dict,
        save_every: int,
        instruction: str,
    ):
        """Store the run configuration.

        Args:
            in_file_path: Excel file to read.
            out_file_path: Excel file the (partially) completed table is written to.
            max_examples: Upper bound on the number of few-shot examples.
            model_name: Model identifier passed to the chat completion call.
            api_key: API key for the OpenAI-format endpoint.
            base_url: Base URL of the OpenAI-format endpoint.
            generation_config: Extra keyword arguments forwarded to
                ``chat.completions.create``.
            save_every: Write a checkpoint after this many predicted rows;
                ``0`` or a negative value disables intermediate checkpoints.
            instruction: Text placed at the top of every prompt.
        """
        self.in_file_path = in_file_path
        self.out_file_path = out_file_path
        self.max_examples = max_examples
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.generation_config = generation_config
        self.save_every = save_every
        self.instruction = instruction

        # Populated while running; initialised here so they always exist.
        self.in_context = ""
        self.query_example = ""
        self.data = None
        self._client = None

    # ─── Helpers ──────────────────────────────────────────────────────────

    @staticmethod
    def _tag_name(column, prefix: str) -> str:
        """Return the column name with ``prefix`` removed from its start only."""
        return str(column).removeprefix(prefix)

    @staticmethod
    def _wrap(tag: str, value) -> str:
        """Render one ``<tag>value</tag>`` line terminated by a newline."""
        return f"<{tag}>{value}</{tag}>\n"

    # ─── IO ───────────────────────────────────────────────────────────────

    def load_excel(self) -> tuple[pd.DataFrame, list, list]:
        """Load the Excel file and identify input and output fields.

        Returns:
            ``(df, input_fields, output_fields)`` where the two lists hold the
            column names starting with ``"[Input] "`` and ``"[Output] "``.
        """
        df = pd.read_excel(self.in_file_path)
        # Headers can be non-strings (numbers, dates); those are never fields.
        text_columns = [col for col in df.columns if isinstance(col, str)]
        input_fields = [c for c in text_columns if c.startswith(self.INPUT_PREFIX)]
        output_fields = [c for c in text_columns if c.startswith(self.OUTPUT_PREFIX)]
        return df, input_fields, output_fields

    # ─── LLM ──────────────────────────────────────────────────────────────

    def openai_request(self, query: str) -> str:
        """Make a request to an OpenAI-format API.

        Args:
            query: Full prompt, sent as a single user message.

        Returns:
            The stripped text of the first choice, or ``""`` when the API
            returns no text content.
        """
        # Reuse one client across rows instead of building one per request.
        client = getattr(self, "_client", None)
        if client is None:
            client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
            self._client = client
        response = client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": query}],
            **self.generation_config,
        )
        # ``content`` may be None; fall back to an empty string before strip.
        content = response.choices[0].message.content
        return (content or "").strip()

    # ─── In-Context Learning ──────────────────────────────────────────────

    def derive_incontext(
        self, data: pd.DataFrame, input_columns: list[str], output_columns: list[str]
    ) -> str:
        """Derive the in-context prompt with angle brackets.

        Only rows in which every output column is filled are used, at most
        ``self.max_examples`` of them, in table order. The result is also
        stored on ``self.in_context``.

        Args:
            data: Table holding both the input and output columns.
            input_columns: Names of the input columns.
            output_columns: Names of the output columns.

        Returns:
            The examples, one tagged line per field, with a blank line after
            each example.
        """
        # Select from the filtered frame so incomplete rows never become
        # examples, wherever they sit in the sheet.
        limit = max(self.max_examples, 0)
        examples = data.dropna(subset=output_columns).head(limit)

        parts = []
        for pos in range(len(examples)):
            for col in input_columns:
                tag = self._tag_name(col, self.INPUT_PREFIX)
                parts.append(self._wrap(tag, examples[col].iloc[pos]))
            for col in output_columns:
                tag = self._tag_name(col, self.OUTPUT_PREFIX)
                parts.append(self._wrap(tag, examples[col].iloc[pos]))
            parts.append("\n")

        in_context = "".join(parts)
        self.in_context = in_context
        return in_context

    def predict_output(
        self, in_context: str, input_data: pd.Series, input_fields: list[str]
    ) -> str:
        """Predict the output values for the given input data using the API.

        Args:
            in_context: Few-shot examples block placed after the instruction.
            input_data: One row of the table (anything indexable by column name).
            input_fields: Names of the input columns to include in the prompt.

        Returns:
            The raw text returned by ``openai_request``. The prompt that was
            sent is kept on ``self.query_example``.
        """
        row_block = "".join(
            self._wrap(self._tag_name(col, self.INPUT_PREFIX), input_data[col])
            for col in input_fields
        )
        query = self.instruction + "\n\n" + in_context + row_block
        self.query_example = query
        return self.openai_request(query)

    def extract_fields(
        self, response: str, output_columns: list[str]
    ) -> dict[str, str]:
        """Extract fields from the response text based on output columns.

        Args:
            response: Model output expected to contain ``<field>...</field>`` spans.
            output_columns: Names of the output columns to look for.

        Returns:
            Mapping from each output column to the text inside its tag, or
            ``""`` when the tag is not found.
        """
        extracted = {}
        for col in output_columns:
            # Escape the name so regex metacharacters in it match literally.
            field = re.escape(self._tag_name(col, self.OUTPUT_PREFIX))
            # DOTALL lets a value span several lines.
            match = re.search(f"<{field}>(.*?)</{field}>", response, re.DOTALL)
            extracted[col] = match.group(1) if match else ""
        return extracted

    # ─── Engine ───────────────────────────────────────────────────────────

    def run(self):
        """Predict every row with a missing output and write the result file.

        Loads the sheet, builds the few-shot block from the complete rows,
        queries the model for each row with at least one empty output cell,
        stores the parsed fields, and saves to ``self.out_file_path``
        (periodically, and once more at the end).
        """
        data, input_fields, output_fields = self.load_excel()
        in_context = self.derive_incontext(data, input_fields, output_fields)

        # Empty output columns load as float64; cast to object so string
        # predictions can be stored without a dtype conflict.
        for col in output_fields:
            data[col] = data[col].astype(object)

        # Rows needing a prediction, wherever they are in the sheet.
        if output_fields:
            missing = data[output_fields].isna().any(axis=1).tolist()
        else:
            missing = []
        pending = [pos for pos, is_missing in enumerate(missing) if is_missing]

        for done, pos in enumerate(tqdm(pending), start=1):
            prediction = self.predict_output(in_context, data.iloc[pos], input_fields)
            extracted_fields = self.extract_fields(prediction, output_fields)
            label = data.index[pos]  # translate position to index label for .at
            for field_name in output_fields:
                data.at[label, field_name] = extracted_fields.get(field_name, "")
            # Guard against save_every <= 0, which would otherwise divide by zero.
            if self.save_every and self.save_every > 0 and done % self.save_every == 0:
                data.to_excel(self.out_file_path, index=False)
        self.data = data
        data.to_excel(self.out_file_path, index=False)
        print(f"Results saved to {self.out_file_path}")