Ki-Seki committed on
Commit
6e11e8c
·
1 Parent(s): 0269e89

chore: update

Files changed (3)
  1. app.py +1 -1
  2. autotab.py +25 -12
  3. requirements.txt +1 -0
app.py CHANGED
@@ -51,7 +51,7 @@ inputs = [
         value='{"temperature": 0, "max_tokens": 128}',
         label="Generation Config in Dict",
     ),
-    gr.Slider(value=0.01, minimum=0, maximum=10, label="Request Interval in Seconds"),
+    gr.Slider(value=0.1, minimum=0, maximum=10, label="Request Interval in Seconds"),
     gr.Slider(value=100, minimum=1, maximum=1000, step=1, label="Save Every N Steps"),
     gr.Textbox(
         value="sk-exhahhjfqyanmwewndukcqtrpegfdbwszkjucvcpajdufiah", label="API Key"
autotab.py CHANGED
@@ -1,8 +1,10 @@
 import re
 import time
+from concurrent.futures import ThreadPoolExecutor
 
 import openai
 import pandas as pd
+from tenacity import retry, stop_after_attempt, wait_random_exponential
 from tqdm import tqdm
 
 
@@ -42,8 +44,10 @@ class AutoTab:
 
     # ─── LLM ──────────────────────────────────────────────────────────────
 
+    @retry(wait=wait_random_exponential(min=20, max=60), stop=stop_after_attempt(6))
     def openai_request(self, query: str) -> str:
         """Make a request to an OpenAI-format API."""
+        time.sleep(self.request_interval)
         client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model_name,
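The new @retry decorator wraps openai_request in randomized exponential backoff: each failed call waits 20-60 seconds before retrying, and after six attempts tenacity raises a RetryError (which the try/except added to run() below turns into a logged error rather than a crash). A self-contained sketch of the pattern, with flaky_request as a hypothetical stand-in for the API call:

    from tenacity import retry, stop_after_attempt, wait_random_exponential

    attempts = 0

    @retry(wait=wait_random_exponential(min=20, max=60), stop=stop_after_attempt(6))
    def flaky_request() -> str:
        # hypothetical stand-in for openai_request: fails twice, then succeeds
        global attempts
        attempts += 1
        if attempts < 3:
            raise ConnectionError("transient API error")
        return "ok"

    print(flaky_request())  # returns "ok" on the third attempt, after two 20-60 s waits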
@@ -103,16 +107,23 @@ class AutoTab:
 
     # ─── Engine ───────────────────────────────────────────────────────────
 
+    def _predict_and_extract(self, i: int) -> dict[str, str]:
+        """Helper function to predict and extract fields for a single row."""
+        prediction = self.predict_output(
+            self.in_context, self.data.iloc[i], self.input_fields
+        )
+        extracted_fields = self.extract_fields(prediction, self.output_fields)
+        return extracted_fields
+
     def batch_prediction(self, start_index: int, end_index: int):
-        """Process a batch of predictions."""
-        for i in range(start_index, end_index):
-            prediction = self.predict_output(
-                self.in_context, self.data.iloc[i], self.input_fields
+        """Process a batch of predictions asynchronously."""
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._predict_and_extract, range(start_index, end_index))
             )
-            extracted_fields = self.extract_fields(prediction, self.output_fields)
+        for i, extracted_fields in zip(range(start_index, end_index), results):
             for field_name in self.output_fields:
                 self.data.at[i, field_name] = extracted_fields.get(field_name, "")
-        time.sleep(self.request_interval)
 
     def run(self):
         self.data, self.input_fields, self.output_fields = self.load_excel()
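batch_prediction now fans a batch's rows out across a thread pool (thread-based concurrency, despite the docstring's "asynchronously"). The key property is that executor.map yields results in submission order, which is what lets the zip afterwards realign each result dict with its row index; the time.sleep moved into openai_request throttles each worker independently. A minimal sketch, with predict_and_extract as a hypothetical stand-in for the helper:

    from concurrent.futures import ThreadPoolExecutor

    def predict_and_extract(i: int) -> dict[str, str]:
        # hypothetical stand-in for AutoTab._predict_and_extract
        return {"answer": f"prediction for row {i}"}

    with ThreadPoolExecutor() as executor:
        # map() yields results in submission order, even if calls finish out of order
        results = list(executor.map(predict_and_extract, range(10, 20)))

    for i, fields in zip(range(10, 20), results):
        print(i, fields["answer"])  # row index and result stay aligned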
@@ -123,12 +134,14 @@ class AutoTab:
         self.num_data = len(self.data)
         self.num_examples = len(self.data.dropna(subset=self.output_fields))
 
-        for start_index in tqdm(
-            range(self.num_examples, self.num_data, self.save_every),
-            description="Processing batches",
-        ):
-            end_index = min(start_index + self.save_every, self.num_data)
-            self.batch_prediction(start_index, end_index)
+        tqdm_bar = tqdm(range(self.num_examples, self.num_data, self.save_every))
+        for start in tqdm_bar:
+            tqdm_bar.update(start)
+            end = min(start + self.save_every, self.num_data)
+            try:
+                self.batch_prediction(start, end)
+            except Exception as e:
+                print(e)
             self.data.to_excel(self.out_file_path, index=False)
         self.data.to_excel(self.out_file_path, index=False)
         print(f"Results saved to {self.out_file_path}")
 
requirements.txt CHANGED
@@ -3,3 +3,4 @@ openai
 argparse
 openpyxl
 gradio
+tenacity