Spaces:
Runtime error
Runtime error
| import ast | |
| import copy | |
| import toml | |
| from string import Template | |
| from pathlib import Path | |
| from flatdict import FlatDict | |
| import google.generativeai as genai | |
| from gen.utils import parse_first_json_snippet | |
def determine_model_name(given_image=None):
    """Choose which Gemini model to call.

    Only ``None`` means "no image"; any other value, even a falsy one such
    as ``b""``, selects the vision-capable model.

    Returns:
        ``"gemini-pro"`` for text-only requests, ``"gemini-pro-vision"``
        otherwise.
    """
    return "gemini-pro" if given_image is None else "gemini-pro-vision"
def construct_image_part(given_image, mime_type="image/jpeg"):
    """Wrap image data as an inline content part for ``generate_content``.

    Args:
        given_image: The image payload stored under ``"data"``; presumably
            raw encoded image bytes (TODO confirm with callers).
        mime_type: MIME type describing the payload. Defaults to
            ``"image/jpeg"``, which was previously the only supported type,
            so existing callers are unaffected.

    Returns:
        A dict with ``"mime_type"`` and ``"data"`` keys.
    """
    return {
        "mime_type": mime_type,
        "data": given_image,
    }
def call_gemini(prompt="", API_KEY=None, given_text=None, given_image=None, generation_config=None, safety_settings=None):
    """Send one request to Gemini and return the reply text.

    Args:
        prompt: Instruction text placed first in the request.
        API_KEY: Passed to ``genai.configure``. This reconfigures the
            module-level SDK client on every call.
        given_text: Optional source material. When given, the request text
            becomes *prompt*, a dashed separator line, then *given_text*.
        given_image: Optional image payload. When not ``None``, the vision
            model is selected and the payload is attached as an extra part
            built by ``construct_image_part``.
        generation_config: Sampling settings. If ``None``, the defaults
            below are used (temperature 0.8, top_p 1, top_k 32,
            max_output_tokens 4096).
        safety_settings: Safety settings. If ``None``, every listed harm
            category is set to ``BLOCK_NONE``.

    Returns:
        ``response.text`` of the generated reply.

    Raises:
        Any SDK error. Presumably ``response.text`` also raises when the
        reply carries no text (e.g. it was blocked) — TODO confirm. Callers
        such as ``try_out`` catch all exceptions and retry.
    """
    genai.configure(api_key=API_KEY)

    if generation_config is None:
        generation_config = {
            "temperature": 0.8,
            "top_p": 1,
            "top_k": 32,
            "max_output_tokens": 4096,
        }

    if safety_settings is None:
        safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in (
                "HARM_CATEGORY_HARASSMENT",
                "HARM_CATEGORY_HATE_SPEECH",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "HARM_CATEGORY_DANGEROUS_CONTENT",
            )
        ]

    model = genai.GenerativeModel(
        model_name=determine_model_name(given_image),
        generation_config=generation_config,
        safety_settings=safety_settings,
    )

    # Assign (not append): the combined string already starts with the
    # prompt, so appending it to `prompt` would send the instruction twice.
    user_prompt = prompt
    if given_text is not None:
        user_prompt = f"{prompt}\n------------------------------------------------\n{given_text}\n"

    prompt_parts = [user_prompt]
    if given_image is not None:
        prompt_parts.append(construct_image_part(given_image))

    response = model.generate_content(prompt_parts)
    return response.text
def try_out(prompt, given_text, gemini_api_key, given_image=None, retry_num=5):
    """Call Gemini and parse the first JSON snippet of its reply, with retries.

    An attempt fails when ``call_gemini`` or ``parse_first_json_snippet``
    raises, or when the parser returns ``None``. Every failed attempt counts
    toward *retry_num*. Previously a ``None`` parse result was not counted,
    which could loop forever and keep calling the API.

    Args:
        prompt: Instruction text forwarded to ``call_gemini``.
        given_text: Source material forwarded to ``call_gemini``.
        gemini_api_key: API key forwarded to ``call_gemini``.
        given_image: Optional image payload forwarded to ``call_gemini``.
        retry_num: Maximum number of attempts. A value <= 0 makes no
            attempt at all.

    Returns:
        The parsed JSON value from the first successful attempt, or ``None``
        if every attempt failed.
    """
    attempts = 0
    while attempts < retry_num:
        attempts += 1
        try:
            qna = call_gemini(
                prompt=prompt,
                given_text=given_text,
                given_image=given_image,
                API_KEY=gemini_api_key,
            )
            qna_json = parse_first_json_snippet(qna)
        except Exception as e:
            # Best-effort: API or parse errors only trigger another attempt.
            print(f"......retry {e}")
            continue
        if qna_json is not None:
            return qna_json
    return None
def get_basic_qa(text, gemini_api_key, trucate=7000):
    """Request the initial Q&A set for *text* from Gemini.

    Reads ``constants/prompts.toml`` relative to the current working
    directory, takes its ``basic_qa.prompt`` entry, and sends it together
    with the first *trucate* characters of *text* through ``try_out``.

    Returns:
        Whatever ``try_out`` returns: the parsed JSON reply, or ``None``.
    """
    prompt_file = Path('.') / 'constants' / 'prompts.toml'
    basic_prompt = toml.load(prompt_file)['basic_qa']['prompt']
    excerpt = text[:trucate]
    return try_out(basic_prompt, excerpt, gemini_api_key=gemini_api_key)
def _is_complete_followup(response):
    """Return True if *response* has the follow-up keys get_deep_qa expects."""
    if not isinstance(response, dict):
        return False
    answers = response.get('answers')
    return (
        'follow up question' in response
        and isinstance(answers, dict)
        and 'eli5' in answers
        and 'expert' in answers
    )


def _fetch_followup(prompt, text, gemini_api_key, max_attempts):
    """Query Gemini until a complete follow-up arrives; return None after max_attempts."""
    attempts = 0
    while attempts < max_attempts:
        attempts += 1
        response = try_out(prompt, text, gemini_api_key=gemini_api_key)
        if _is_complete_followup(response):
            return response
    return None


def _flatten_with_index(qna, idx):
    """Flatten nested keys via FlatDict and prefix each flattened key with f'{idx}_'."""
    flat = {}
    for key, value in FlatDict(qna).items():
        # An empty nested object comes back as a FlatDict; store a plain dict.
        if isinstance(value, FlatDict):
            value = value.as_dict()
        flat[f'{idx}_{key}'] = value
    return flat


def get_deep_qa(text, basic_qa, gemini_api_key, trucate=7000, max_followup_attempts=3):
    """Add an in-depth and a broad follow-up question to each basic Q&A entry.

    For every entry of ``basic_qa['qna']``, the ``deep_qa`` prompt template
    from ``constants/prompts.toml`` is filled with the title, the entry's
    question and its expert answer. It is sent twice, once with
    ``tone="in-depth"`` and once with ``tone="broad"``. A reply is accepted
    only if it has ``'follow up question'`` and an ``'answers'`` object
    containing ``'eli5'`` and ``'expert'``. Accepted replies are stored
    under ``additional_depth_q`` / ``additional_breath_q`` on a copy of the
    entry. The copy is then flattened with FlatDict's delimiter, its keys
    are prefixed with the entry index (e.g. ``'0_question'``), and the
    result is merged into *basic_qa*.

    Args:
        text: Source document; only the first *trucate* characters are sent.
        basic_qa: Output of ``get_basic_qa``. It needs ``'title'`` and a
            ``'qna'`` list whose items have ``'question'`` and
            ``['answers']['expert']``. It is MUTATED in place.
        gemini_api_key: API key forwarded to ``try_out``.
        trucate: Number of leading characters of *text* to send.
        max_followup_attempts: How many ``try_out`` calls to spend on each
            follow-up before giving up and omitting it. Previously this
            retried forever.

    Returns:
        The same *basic_qa* dict, updated with the flattened follow-up keys.
    """
    prompts = toml.load(Path('.') / 'constants' / 'prompts.toml')
    prompt_template = Template(prompts['deep_qa']['prompt'])
    title = basic_qa['title']
    excerpt = text[:trucate]
    qnas = copy.deepcopy(basic_qa['qna'])

    for idx, qna in enumerate(qnas):
        q = qna['question']
        a_expert = qna['answers']['expert']
        depth_search_prompt = prompt_template.substitute(
            title=title, previous_question=q, previous_answer=a_expert, tone="in-depth"
        )
        breath_search_prompt = prompt_template.substitute(
            title=title, previous_question=q, previous_answer=a_expert, tone="broad"
        )

        depth_search_response = _fetch_followup(
            depth_search_prompt, excerpt, gemini_api_key, max_followup_attempts
        )
        breath_search_response = _fetch_followup(
            breath_search_prompt, excerpt, gemini_api_key, max_followup_attempts
        )

        if depth_search_response is not None:
            qna['additional_depth_q'] = depth_search_response
        if breath_search_response is not None:
            qna['additional_breath_q'] = breath_search_response

        basic_qa.update(_flatten_with_index(qna, idx))

    return basic_qa