Hasan Iqbal commited on
Commit
195bb32
·
unverified ·
1 Parent(s): 0fe75ae

Added json fix in factool and factcheckgpt

Browse files
src/openfactcheck/solvers/webservice/factcheckgpt_utils/openai_api.py CHANGED
@@ -6,21 +6,25 @@ import openai
6
  client = None
7
 
8
 
 
 
 
 
9
  def init_client():
10
  global client
11
  if client is None:
12
- if openai.api_key is None and 'OPENAI_API_KEY' not in os.environ:
13
  print("openai_key not presented, delay to initialize.")
14
  return
15
  client = OpenAI()
16
 
17
 
18
  def request(
19
- user_inputs,
20
- model,
21
- system_role,
22
- temperature=1.0,
23
- return_all=False,
24
  ):
25
  init_client()
26
 
@@ -29,9 +33,7 @@ def request(
29
  elif type(user_inputs) == list:
30
  if all([type(x) == str for x in user_inputs]):
31
  chat_histories = [
32
- {
33
- "role": "user" if i % 2 == 0 else "assistant", "content": x
34
- } for i, x in enumerate(user_inputs)
35
  ]
36
  elif all([type(x) == dict for x in user_inputs]):
37
  chat_histories = user_inputs
@@ -39,31 +41,23 @@ def request(
39
  raise ValueError("Invalid input for OpenAI API calling")
40
  else:
41
  raise ValueError("Invalid input for OpenAI API calling")
42
-
43
 
44
  messages = [{"role": "system", "content": system_role}] + chat_histories
45
 
46
- response = client.chat.completions.create(
47
- model=model,
48
- messages=messages,
49
- temperature=temperature
50
- )
51
  if return_all:
52
  return response
53
- response_str = ''
54
  for choice in response.choices:
55
  response_str += choice.message.content
56
  return response_str
57
 
58
 
59
- def gpt(
60
- user_inputs,
61
- model,
62
- system_role,
63
- temperature=1.0,
64
- num_retries=3,
65
- waiting=1
66
- ):
67
  response = None
68
  for _ in range(num_retries):
69
  try:
 
6
  client = None
7
 
8
 
9
+ def _json_fix(output):
10
+ return output.replace("```json\n", "").replace("```", "")
11
+
12
+
13
  def init_client():
14
  global client
15
  if client is None:
16
+ if openai.api_key is None and "OPENAI_API_KEY" not in os.environ:
17
  print("openai_key not presented, delay to initialize.")
18
  return
19
  client = OpenAI()
20
 
21
 
22
  def request(
23
+ user_inputs,
24
+ model,
25
+ system_role,
26
+ temperature=1.0,
27
+ return_all=False,
28
  ):
29
  init_client()
30
 
 
33
  elif type(user_inputs) == list:
34
  if all([type(x) == str for x in user_inputs]):
35
  chat_histories = [
36
+ {"role": "user" if i % 2 == 0 else "assistant", "content": x} for i, x in enumerate(user_inputs)
 
 
37
  ]
38
  elif all([type(x) == dict for x in user_inputs]):
39
  chat_histories = user_inputs
 
41
  raise ValueError("Invalid input for OpenAI API calling")
42
  else:
43
  raise ValueError("Invalid input for OpenAI API calling")
 
44
 
45
  messages = [{"role": "system", "content": system_role}] + chat_histories
46
 
47
+ response = client.chat.completions.create(model=model, messages=messages, temperature=temperature)
48
+
49
+ # Fix the json format
50
+ response = _json_fix(response)
51
+
52
  if return_all:
53
  return response
54
+ response_str = ""
55
  for choice in response.choices:
56
  response_str += choice.message.content
57
  return response_str
58
 
59
 
60
+ def gpt(user_inputs, model, system_role, temperature=1.0, num_retries=3, waiting=1):
 
 
 
 
 
 
 
61
  response = None
62
  for _ in range(num_retries):
63
  try:
src/openfactcheck/solvers/webservice/factool_utils/chat_api.py CHANGED
@@ -72,6 +72,9 @@ class OpenAIChat:
72
  else:
73
  return None
74
 
 
 
 
75
  def _boolean_fix(self, output):
76
  return output.replace("true", "True").replace("false", "False")
77
 
@@ -166,7 +169,9 @@ class OpenAIChat:
166
  )
167
 
168
  preds = [
169
- self._type_check(self._boolean_fix(prediction.choices[0].message.content), expected_type)
 
 
170
  if prediction is not None
171
  else None
172
  for prediction in predictions
 
72
  else:
73
  return None
74
 
75
+ def _json_fix(self, output):
76
+ return output.replace("```json\n", "").replace("```", "")
77
+
78
  def _boolean_fix(self, output):
79
  return output.replace("true", "True").replace("false", "False")
80
 
 
169
  )
170
 
171
  preds = [
172
+ self._type_check(
173
+ self._boolean_fix(self._json_fix(prediction.choices[0].message.content)), expected_type
174
+ )
175
  if prediction is not None
176
  else None
177
  for prediction in predictions