JayWadekar commited on
Commit
699fe5f
·
1 Parent(s): 60b6fa4

Added code to output references

Browse files
Files changed (2) hide show
  1. app.py +0 -3
  2. rag.py +35 -5
app.py CHANGED
@@ -2,9 +2,6 @@
2
  # the gwIAS search pipeline
3
  # using Langchain and deployed with Gradio
4
 
5
- # Thanks to Pablo Villanueva Domingo for sharing his CAMELS template
6
- # https://huggingface.co/spaces/PabloVD/CAMELSDocBot
7
-
8
  from rag import RAG, load_docs
9
  from langchain_community.embeddings import HuggingFaceInstructEmbeddings
10
  from langchain.chat_models import ChatOpenAI
 
2
  # the gwIAS search pipeline
3
  # using Langchain and deployed with Gradio
4
 
 
 
 
5
  from rag import RAG, load_docs
6
  from langchain_community.embeddings import HuggingFaceInstructEmbeddings
7
  from langchain.chat_models import ChatOpenAI
rag.py CHANGED
@@ -23,12 +23,37 @@ def load_docs():
23
  # Load, chunk and index the contents of the blog.
24
  loader = WebBaseLoader(urls)
25
  docs = loader.load()
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  return docs
28
 
 
 
 
 
 
 
 
 
29
  # Join content pages for processing
30
  def format_docs(docs):
31
- return "\n\n".join(doc.page_content for doc in docs)
 
 
 
 
 
 
32
 
33
  # Create a RAG chain
34
  def RAG(llm, docs, embeddings):
@@ -45,6 +70,14 @@ def RAG(llm, docs, embeddings):
45
 
46
  # Prompt basis example for RAG systems
47
  prompt = hub.pull("rlm/rag-prompt")
 
 
 
 
 
 
 
 
48
 
49
  # Create the chain
50
  rag_chain = (
@@ -54,7 +87,4 @@ def RAG(llm, docs, embeddings):
54
  | StrOutputParser()
55
  )
56
 
57
- return rag_chain
58
-
59
-
60
- # Debugging push
 
23
  # Load, chunk and index the contents of the blog.
24
  loader = WebBaseLoader(urls)
25
  docs = loader.load()
26
+
27
+ # Add source URLs as document names for reference
28
+ for i, doc in enumerate(docs):
29
+ if 'source' in doc.metadata:
30
+ doc.metadata['name'] = doc.metadata['source']
31
+ else:
32
+ doc.metadata['name'] = f"Document {i+1}"
33
+
34
+ print(f"Loaded {len(docs)} documents:")
35
+ for doc in docs:
36
+ print(f" - {doc.metadata.get('name')}")
37
 
38
  return docs
39
 
40
def extract_reference(url):
    """Extract a reference keyword from the GitHub URL.

    Returns the repository path that follows ``blob/main/`` or
    ``tree/main/`` ("root" for an empty tree path); any other URL is
    returned unchanged.
    """
    if "blob/main" in url:
        # Path of the file within the repository.
        return url.split("blob/main/")[-1]
    if "tree/main" in url:
        # Directory path; an empty tail means the repository root.
        tail = url.split("tree/main/")[-1]
        return tail if tail else "root"
    return url
48
  # Join content pages for processing
49
def format_docs(docs):
    """Join document contents for the RAG context, appending a bracketed
    reference (derived from the source URL) after each document."""

    def _render(doc):
        # Fall back to a placeholder when no source URL was recorded by the loader.
        src = doc.metadata.get('source', 'Unknown source')
        return f"{doc.page_content}\n\nReference: [{extract_reference(src)}]"

    return "\n\n---\n\n".join(_render(d) for d in docs)
57
 
58
  # Create a RAG chain
59
  def RAG(llm, docs, embeddings):
 
70
 
71
  # Prompt basis example for RAG systems
72
  prompt = hub.pull("rlm/rag-prompt")
73
# Adding custom instructions to the prompt: ask the model to cite the
# bracketed reference IDs that format_docs() appends to each document.
template = prompt.messages[0].prompt.template
template_parts = template.split("\nQuestion: {question}")
if len(template_parts) != 2:
    # Hub template changed shape; leave it untouched rather than corrupt it.
    # (Original code had the condition inverted and then indexed
    # template_parts[1] unconditionally, raising IndexError on failure.)
    print("Error: Template does not contain the expected format.")
else:
    combined_template = (
        template_parts[0]
        + " Include the reference IDs in square brackets when citing specific information."
        # Re-insert the separator consumed by split() so the
        # "Question: {question}" placeholder survives the edit.
        + "\nQuestion: {question}"
        + template_parts[1]
    )
    prompt.messages[0].prompt.template = combined_template
81
 
82
  # Create the chain
83
  rag_chain = (
 
87
  | StrOutputParser()
88
  )
89
 
90
+ return rag_chain