import os
import sys
import inspect
import json
from json import JSONDecodeError
import tiktoken
import random
import google.generativeai as palm
currentdir = os.path.dirname(os.path.abspath(
    inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)
from prompt_catalog import PromptCatalog
from general_utils import num_tokens_from_string
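# Helpers for converting OCR text into a structured JSON dictionary using the
# PaLM 2 text API (google.generativeai), with prompts built by PromptCatalog
# and retry logic for responses that do not parse as JSON.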
"""
DEPRECATED:
Safety setting regularly block a response, so set to 4 to disable
class HarmBlockThreshold(Enum):
HARM_BLOCK_THRESHOLD_UNSPECIFIED = 0
BLOCK_LOW_AND_ABOVE = 1
BLOCK_MEDIUM_AND_ABOVE = 2
BLOCK_ONLY_HIGH = 3
BLOCK_NONE = 4
"""
SAFETY_SETTINGS = [
    {
        "category": "HARM_CATEGORY_DEROGATORY",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_TOXICITY",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_VIOLENCE",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUAL",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_MEDICAL",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS",
        "threshold": "BLOCK_NONE",
    },
]
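# Generation settings for the PaLM 2 text model. The first-pass call is
# deterministic (temperature 0); PALM_SETTINGS_REDO is identical except for a
# slightly higher temperature (0.05) and is used for the final retry when
# earlier responses fail to parse as JSON.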
PALM_SETTINGS = {
    'model': 'models/text-bison-001',
    'temperature': 0,
    'candidate_count': 1,
    'top_k': 40,
    'top_p': 0.95,
    'max_output_tokens': 8000,
    'stop_sequences': [],
    'safety_settings': SAFETY_SETTINGS,
}
PALM_SETTINGS_REDO = {
    'model': 'models/text-bison-001',
    'temperature': 0.05,
    'candidate_count': 1,
    'top_k': 40,
    'top_p': 0.95,
    'max_output_tokens': 8000,
    'stop_sequences': [],
    'safety_settings': SAFETY_SETTINGS,
}
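# Build a prompt for the requested prompt_version (optionally pulling few-shot
# input/output examples from the VVE domain-knowledge store), call PaLM 2,
# validate the JSON response, and return the parsed dictionary along with the
# prompt's token count.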
def OCR_to_dict_PaLM(logger, OCR, prompt_version, VVE):
    try:
        logger.info(f'Length of OCR raw -- {len(OCR)}')
    except Exception:
        print(f'Length of OCR raw -- {len(OCR)}')

    # prompt = PROMPT_PaLM_UMICH_skeleton_all_asia(OCR, in_list, out_list) # must provide examples to PaLM differently than for chatGPT, at least 2 examples
    Prompt = PromptCatalog(OCR)
    if prompt_version in ['prompt_v2_palm2']:
        version = 'v2'
        prompt = Prompt.prompt_v2_palm2(OCR)
    elif prompt_version in ['prompt_v1_palm2']:
        version = 'v1'
        # create input:/output: examples for PaLM
        # Find similar examples from the domain knowledge
        domain_knowledge_example = VVE.query_db(OCR, 4)
        similarity = VVE.get_similarity()
        domain_knowledge_example_string = json.dumps(domain_knowledge_example)
        in_list, out_list = create_OCR_analog_for_input(domain_knowledge_example)
        prompt = Prompt.prompt_v1_palm2(in_list, out_list, OCR)
    elif prompt_version in ['prompt_v1_palm2_noDomainKnowledge']:
        version = 'v1'
        prompt = Prompt.prompt_v1_palm2_noDomainKnowledge(OCR)
    else:
        version = 'custom'
        prompt, n_fields, xlsx_headers = Prompt.prompt_v2_custom(prompt_version, OCR=OCR, is_palm=True)
    # raise
    nt = num_tokens_from_string(prompt, "cl100k_base")
    # try:
    logger.info(f'Prompt token length --- {nt}')
    # except:
    #     print(f'Prompt token length --- {nt}')
    do_use_SOP = False  ########
    if do_use_SOP:
        '''TODO: Check back later to see if LangChain will support PaLM'''
        # logger.info(f'Waiting for PaLM API call --- Using StructuredOutputParser')
        # response = structured_output_parser(OCR, prompt, logger)
        # return response['Dictionary']
        pass
    else:
        # try:
        logger.info(f'Waiting for PaLM 2 API call')
        # except:
        #     print(f'Waiting for PaLM 2 API call --- Content')

        # safety_thresh = 4
        # PaLM_settings = {'model': 'models/text-bison-001','temperature': 0,'candidate_count': 1,'top_k': 40,'top_p': 0.95,'max_output_tokens': 8000,'stop_sequences': [],
        #                  'safety_settings': [{"category":"HARM_CATEGORY_DEROGATORY","threshold":safety_thresh},{"category":"HARM_CATEGORY_TOXICITY","threshold":safety_thresh},{"category":"HARM_CATEGORY_VIOLENCE","threshold":safety_thresh},{"category":"HARM_CATEGORY_SEXUAL","threshold":safety_thresh},{"category":"HARM_CATEGORY_MEDICAL","threshold":safety_thresh},{"category":"HARM_CATEGORY_DANGEROUS","threshold":safety_thresh}],}
        response = palm.generate_text(prompt=prompt, **PALM_SETTINGS)

        if response and response.result:
            if isinstance(response.result, (str, bytes)):
                response_valid = check_and_redo_JSON(response, logger, version)
            else:
                response_valid = {}
        else:
            response_valid = {}

        logger.info(f'Candidate JSON\n{getattr(response, "result", None)}')
        return response_valid, nt
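# Validate a PaLM response as JSON with escalating retries: parse the raw result;
# if that fails, strip the ```json code fence and parse again; if that also fails,
# re-prompt the model with a redo prompt, first with PALM_SETTINGS and then once
# more with PALM_SETTINGS_REDO (temperature 0.05). Returns the parsed dict, or
# None if every attempt fails.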
def check_and_redo_JSON(response, logger, version):
    try:
        response_valid = json.loads(response.result)
        logger.info(f'Response --- First call passed')
        return response_valid
    except JSONDecodeError:
        try:
            response_valid = json.loads(response.result.strip('```').replace('json\n', '', 1).replace('json', '', 1))
            logger.info(f'Response --- Manual removal of ```json succeeded')
            return response_valid
        except Exception:
            logger.info(f'Response --- First call failed. Redo...')

            Prompt = PromptCatalog()
            if version == 'v1':
                prompt_redo = Prompt.prompt_palm_redo_v1(response.result)
            elif version == 'v2':
                prompt_redo = Prompt.prompt_palm_redo_v2(response.result)
            elif version == 'custom':
                prompt_redo = Prompt.prompt_v2_custom_redo(response.result, is_palm=True)
            # prompt_redo = PROMPT_PaLM_Redo(response.result)

            try:
                response = palm.generate_text(prompt=prompt_redo, **PALM_SETTINGS)
                response_valid = json.loads(response.result)
                logger.info(f'Response --- Second call passed')
                return response_valid
            except JSONDecodeError:
                logger.info(f'Response --- Second call failed. Final redo. Temperature changed to 0.05')

                try:
                    response = palm.generate_text(prompt=prompt_redo, **PALM_SETTINGS_REDO)
                    response_valid = json.loads(response.result)
                    logger.info(f'Response --- Third call passed')
                    return response_valid
                except JSONDecodeError:
                    return None
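# Turn domain-knowledge rows into few-shot examples for the v1 PaLM prompt:
# each row's values are shuffled and joined into a pseudo-OCR input string,
# with the row's JSON form kept as the corresponding target output.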
def create_OCR_analog_for_input(domain_knowledge_example):
    in_list = []
    out_list = []
    # Iterate over the domain_knowledge_example (list of dictionaries)
    for row_dict in domain_knowledge_example:
        # Convert the dictionary to a JSON string and add it to the out_list
        domain_knowledge_example_string = json.dumps(row_dict)
        out_list.append(domain_knowledge_example_string)

        # Create a single string from all values in the row_dict
        row_text = '||'.join(str(v) for v in row_dict.values())

        # Split the row text by '||', shuffle the parts, and then re-join with a single space
        parts = row_text.split('||')
        random.shuffle(parts)
        shuffled_text = ' '.join(parts)

        # Add the shuffled_text to the in_list
        in_list.append(shuffled_text)

    return in_list, out_list
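# Drop non-printable characters (e.g. stray control codes from OCR) from a string.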
def strip_problematic_chars(s):
    return ''.join(c for c in s if c.isprintable())