# Copyright (c) OpenMMLab. All rights reserved.
"""Search enhancement proxy."""
import argparse
import json
import os

import pytoml
from loguru import logger

from .llm_client import ChatClient


class SourceGraphProxy:
    """A proxy class for interacting with Sourcegraph.

    Args:
        config_path (str): Path to the configuration file.
        topk (int, optional): Top K results to keep from the search. Defaults to 1.  # noqa E501
        language (str, optional): Language for the system prompts - 'zh' for Chinese and 'en' for English. Defaults to 'zh'.  # noqa E501

    Attributes:
        config_path (str): The path of the configuration file.
        sg_config (dict): Configuration settings for the Sourcegraph search.
        topk (int): Top K results to keep from the search.
        language (str): Language for the system prompts.
        CHOICE_TEMPLATE (str): Prompt template for choosing a repository, in the selected language.  # noqa E501
        KEYWORDS_TEMPLATE (str): Prompt template for extracting search keywords, in the selected language.  # noqa E501
    """

    def __init__(self,
                 config_path: str,
                 topk: int = 1,
                 language: str = 'zh') -> None:
        """Init the searcher with the given config."""
        self.config_path = config_path
        self.sg_config = None
        with open(self.config_path, encoding='utf8') as f:
            config = pytoml.load(f)
            self.sg_config = config['sg_search']
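        # Illustrative layout of the [sg_search] section this class reads;
        # the key names mirror the accesses in this file (`src_access_token`,
        # `binary_src_path`, plus one sub-table per searchable repo carrying
        # an `introduction` and a `github_repo_id`). Values are placeholders.
        #
        # [sg_search]
        # binary_src_path = "/usr/local/bin/src"
        # src_access_token = "YOUR_TOKEN"
        #
        # [sg_search.opencompass]
        # github_repo_id = "<repo id>"
        # introduction = "<one-line description shown to the LLM>"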
        self.topk = topk
        self.language = language
        if self.language == 'zh':
            self.CHOICE_TEMPLATE = '“{}”\n请仔细阅读以上问题,请问应该查询以下哪个开源项目:\n'  # noqa E501
            self.KEYWORDS_TEMPLATE = '“{}”\n请仔细阅读以上问题,提取其中可用作搜索引擎的关键字,关键字之间用 , 分隔,不要解释。'  # noqa E501
        else:
            self.CHOICE_TEMPLATE = '"{}"\nPlease read the above question carefully; which of the following open-source projects should be queried:\n'  # noqa E501
            self.KEYWORDS_TEMPLATE = '"{}"\nPlease read the above question carefully and extract keywords suitable for a search engine. Separate the keywords with commas and do not explain.'  # noqa E501

    def command(self, txt: str):
        """Execute a shell command and return its output.

        Args:
            txt (str): Command to be executed in the shell.

        Returns:
            str: Output of the shell command execution.
        """
        logger.debug('cmd: {}'.format(txt))
        cmd = os.popen(txt)
        return cmd.read().strip()

    def extract_sg_result(self, jsonstr):
        """Extract the desired data from a Sourcegraph search result.

        Args:
            jsonstr (str): JSON string containing the Sourcegraph search result.

        Returns:
            list: List of dicts, each containing the 'filepath' and 'content' of a file returned by Sourcegraph.  # noqa E501
        """
        ret = []
        try:
            root = json.loads(jsonstr)
            results = root['Results']
            for result in results:
                if 'FileMatch' != result['__typename']:
                    continue
                content = result['file']['content']
                path = result['file']['path']
                ret.append({'filepath': path, 'content': content})
                if len(ret) >= self.topk:
                    break
        except Exception as e:
            logger.warning('{} while parsing Sourcegraph result {}'.format(
                str(e), jsonstr))
        return ret

    def choose_repo(self, llm_client, question, groupname):
        """Ask the LLM to pick which repository to search based on the user's
        question.

        Args:
            llm_client: Client instance for the LLM.
            question (str): User's question.
            groupname (str): Name of the user's group.

        Returns:
            str: The GitHub repository ID of the selected repository, or None if no repository matches.  # noqa E501
        """
        prompt = self.CHOICE_TEMPLATE.format(question)
        keys = self.sg_config.keys()
        skip = ['binary_src_path', 'src_access_token']
        repos = {}
        for key in keys:
            if key in skip:
                continue
            introduction = self.sg_config[key]['introduction']
            prompt += f'* {key} {introduction}\n'
            repos[key] = self.sg_config[key]
        prompt += '* none '
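        # With the English templates, the assembled prompt looks roughly like
        # the sketch below; repo names and introductions come from the config,
        # so this is only illustrative:
        #   "<question>"
        #   Please read the above question carefully; which of the following
        #   open-source projects should be queried:
        #   * opencompass <introduction from config>
        #   * none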
        choice = llm_client.generate_response(prompt=prompt,
                                              backend='remote').strip()

        target_repo_id = None
        for key in repos.keys():
            if key in choice:
                target_repo_id = repos[key]['github_repo_id']
                break
        return target_repo_id

    def search(self, llm_client, question, groupname):
        """Search the selected repository based on the user's question.

        Args:
            llm_client: Client instance for the LLM.
            question (str): User's question.
            groupname (str): Name of the user's group.

        Returns:
            str: Search results from Sourcegraph, serialized as a JSON string.
        """
        repo_id = self.choose_repo(llm_client, question, groupname)
        if repo_id is None:
            logger.warning('cannot choose repo_id')
            return ''

        ENV = 'export SRC_ACCESS_TOKEN="{}" && '.format(
            self.sg_config['src_access_token'])
        BINARY = self.sg_config['binary_src_path']

        prompt = self.KEYWORDS_TEMPLATE.format(question)
        entities = []
        entity_str = ''
        try:
            entity_str = llm_client.generate_response(prompt=prompt)
            entities = [
                item.strip() for item in entity_str.split(',') if item.strip()
            ]
        except Exception as e:
            logger.error('parse {} failed {}.'.format(entity_str, str(e)))
            # return ''
            entities = []
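        # Each keyword is searched twice via the Sourcegraph CLI, once over
        # Markdown docs and once over Python sources. The assembled shell
        # command has roughly this shape (token and binary path come from the
        # config; the values here are placeholders):
        #   export SRC_ACCESS_TOKEN="..." && /path/to/src search -json \
        #     'repo:open-compass/opencompass lang:Python summarizers'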
        search_items = []
        for entity in entities:
            # search docs and source code based on the extracted entities
            # e.g. search -json 'repo:open-compass/opencompass summarizers'
            cmd_doc = '''{} search -json 'repo:{} lang:MarkDown {}' '''.format(
                BINARY, repo_id, entity)
            cmd_return = self.command(ENV + cmd_doc)
            search_items += self.extract_sg_result(cmd_return)

            cmd_python = '''{} search -json 'repo:{} lang:Python {}' '''.format(  # noqa E501
                BINARY, repo_id, entity)
            cmd_return = self.command(ENV + cmd_python)
            search_items += self.extract_sg_result(cmd_return)

        search_text = json.dumps(search_items, ensure_ascii=False, indent=2)
        return search_text


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description='Sourcegraph proxy search')
    parser.add_argument(
        '--config_path',
        default='config.ini',
        help='Sourcegraph proxy configuration path. Default value is config.ini'  # noqa E501
    )
    args = parser.parse_args()
    return args
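

# Because ChatClient is imported with a relative import above, this test entry
# point has to be run as a module from the package root, e.g. (module path is
# an assumption):
#   python -m <your_package>.sg_search --config_path config.ini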
if __name__ == '__main__':
    """Test search."""
    logger.add('logs/sg_search.log', rotation='4MB')
    args = parse_args()
    llm = ChatClient(config_path=args.config_path)
    sg = SourceGraphProxy(config_path=args.config_path)
    # Chinese test question, roughly: 'How do I get the triviaqa 5-shot result
    # printed by the summarizer?'
    context = sg.search(llm,
                        question='请问triviaqa 5shot结果怎么在summarizer里输出呢',
                        groupname='opencompass')
    print(context)