|
import argparse |
|
import os |
|
from dotenv import load_dotenv |
|
|
|
from langchain.globals import set_debug |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain_core.output_parsers import StrOutputParser |
|
|
|
from lib.repository import download_github_repo |
|
from lib.loader import load_files |
|
from lib.chain import create_retriever, create_qa_chain |
|
from lib.utils import read_prompt, load_LLM, select_model |
|
from lib.models import MODELS_MAP |
|
|
|
|
|
|
|
def _repo_name_from_url(repo_url):
    """Derive a local directory name from a GitHub repository URL.

    Trailing slashes are ignored and only a *final* ``.git`` suffix is
    removed, so ``https://github.com/owner/repo.git/`` yields ``"repo"``
    while a name such as ``my.github-tools`` is left intact.

    Args:
        repo_url: The repository URL as given on the command line.

    Returns:
        The last path component of the URL without a trailing ``.git``.

    Raises:
        ValueError: If no usable name can be derived (empty, ``.`` or
            ``..``), which would otherwise point the clone directory at or
            above the ``data`` directory.
    """
    name = repo_url.strip().rstrip("/").rsplit("/", 1)[-1]
    if name.endswith(".git"):
        name = name[: -len(".git")]
    if name in ("", ".", ".."):
        raise ValueError(f"Cannot derive a repository name from URL: {repo_url!r}")
    return name


def _ask_questions(qa_chain):
    """Run the interactive question/answer loop until the user quits.

    The loop ends on ``exit`` (case-insensitive, surrounding whitespace
    ignored), on end-of-input (Ctrl-D) or on Ctrl-C at the prompt. Blank
    input is skipped rather than sent to the chain. An error raised while
    answering one question is printed and the loop continues, so a single
    failed request does not throw away the already-built retriever.

    Args:
        qa_chain: The chain built by ``create_qa_chain``. Its ``invoke``
            result is indexed with ``"output"`` — presumably a dict; confirm
            against ``lib.chain``.
    """
    print("You can start asking questions. Type 'exit' to quit.")
    while True:
        try:
            question = input("Question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: end the session without a traceback.
            print()
            break
        if not question:
            continue
        if question.lower() == "exit":
            break
        try:
            answer = qa_chain.invoke(question)
        except Exception as err:  # REPL boundary: report and keep the session alive.
            print(f"Error: {type(err).__name__}: {err}")
            continue
        print(f"Answer: {answer['output']}")


def main():
    """Entry point of the GitHub Repo QA CLI.

    The steps are:

    1. Parse the repository URL from the command line.
    2. Ask the user to pick a model.
    3. Download the repository into ``data/<repo_name>`` next to this script.
    4. Split its files into chunks.
    5. Build a retriever (given ``data/db`` as its database directory) and a
       QA chain from the prompts in ``prompt_templates/``.
    6. Answer questions interactively.
    """
    # Populate os.environ from a .env file if one is found. Presumably the
    # lib helpers read model API keys from the environment — confirm.
    load_dotenv()

    # Parse and validate arguments first, so `--help`, a missing URL or an
    # unusable URL fail immediately instead of after the model-selection prompt.
    parser = argparse.ArgumentParser(description="GitHub Repo QA CLI Application")
    parser.add_argument("repo_url", type=str, help="URL of the GitHub repository")
    args = parser.parse_args()

    repo_url = args.repo_url
    try:
        repo_name = _repo_name_from_url(repo_url)
    except ValueError as err:
        parser.error(str(err))  # exits with status 2

    base_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(base_dir, "data")
    repo_dir = os.path.join(data_dir, repo_name)
    db_dir = os.path.join(data_dir, "db")
    prompt_templates_dir = os.path.join(base_dir, "prompt_templates")

    # The clone and the retriever's database directory share `data/`; a repo
    # literally named "db" would be downloaded into the database directory.
    if os.path.normcase(repo_dir) == os.path.normcase(db_dir):
        parser.error(
            f"Repository name {repo_name!r} collides with the database directory {db_dir}"
        )

    model_name = select_model()
    # Fail fast with a readable message if the selection is not a known model.
    if model_name not in MODELS_MAP:
        raise SystemExit(f"Unknown model: {model_name!r}")

    print(f"Downloading repository from {repo_url}...")
    download_github_repo(repo_url, repo_dir)

    # Each prompt is read from `prompt_templates/<key>.txt`.
    prompts_text = {
        key: read_prompt(os.path.join(prompt_templates_dir, f"{key}.txt"))
        for key in ("initial_prompt", "evaluation_prompt")
    }

    print(f"Loading documents from {repo_dir}...")
    document_chunks = load_files(repository_path=repo_dir)
    print(f"Created chunks length is: {len(document_chunks)}")

    print(f"Creating retrieval QA chain using {model_name}...")
    llm = load_LLM(model_name)
    retriever = create_retriever(model_name, db_dir, document_chunks)
    qa_chain = create_qa_chain(llm, retriever, prompts_text)

    _ask_questions(qa_chain)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()