Joshua Sundance Bailey committed
Commit 21eccfc · 1 Parent(s): 87d6984

create llm_resources.py

langchain-streamlit-demo/app.py CHANGED
@@ -1,31 +1,16 @@
 from datetime import datetime
-from tempfile import NamedTemporaryFile
 from typing import Tuple, List, Dict, Any, Union
 
 import anthropic
 import langsmith.utils
 import openai
 import streamlit as st
-from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
 from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
-from langchain.chains import RetrievalQA
-from langchain.chains.llm import LLMChain
-from langchain.chat_models import (
-    AzureChatOpenAI,
-    ChatAnthropic,
-    ChatAnyscale,
-    ChatOpenAI,
-)
-from langchain.document_loaders import PyPDFLoader
-from langchain.embeddings import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain.retrievers import BM25Retriever, EnsembleRetriever
 from langchain.schema.document import Document
 from langchain.schema.retriever import BaseRetriever
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import FAISS
 from langsmith.client import Client
 from streamlit_feedback import streamlit_feedback
@@ -52,8 +37,7 @@ from defaults import (
     DEFAULT_CHUNK_OVERLAP,
     DEFAULT_RETRIEVER_K,
 )
-from qagen import get_rag_qa_gen_chain
-from summarize import get_rag_summarization_chain
+from llm_resources import get_runnable, get_llm, get_texts_and_retriever, StreamHandler
 
 __version__ = "0.0.13"
 
@@ -84,61 +68,29 @@ st_init_null(
     "trace_link",
 )
 
-# --- Memory ---
+# --- LLM globals ---
 STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
 MEMORY = ConversationBufferMemory(
     chat_memory=STMEMORY,
     return_messages=True,
     memory_key="chat_history",
 )
-
-
-# --- Callbacks ---
-class StreamHandler(BaseCallbackHandler):
-    def __init__(self, container, initial_text=""):
-        self.container = container
-        self.text = initial_text
-
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        self.text += token
-        self.container.markdown(self.text)
-
-
 RUN_COLLECTOR = RunCollectorCallbackHandler()
 
 
 @st.cache_data
-def get_texts_and_retriever(
+def get_texts_and_retriever_cacheable_wrapper(
     uploaded_file_bytes: bytes,
     chunk_size: int = DEFAULT_CHUNK_SIZE,
     chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
     k: int = DEFAULT_RETRIEVER_K,
 ) -> Tuple[List[Document], BaseRetriever]:
-    with NamedTemporaryFile() as temp_file:
-        temp_file.write(uploaded_file_bytes)
-        temp_file.seek(0)
-
-        loader = PyPDFLoader(temp_file.name)
-        documents = loader.load()
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-        )
-        texts = text_splitter.split_documents(documents)
-        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
-
-        bm25_retriever = BM25Retriever.from_documents(texts)
-        bm25_retriever.k = k
-
-        faiss_vectorstore = FAISS.from_documents(texts, embeddings)
-        faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
-
-        ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, faiss_retriever],
-            weights=[0.5, 0.5],
-        )
-
-        return texts, ensemble_retriever
+    return get_texts_and_retriever(
+        uploaded_file_bytes=uploaded_file_bytes,
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+        k=k,
+    )
 
 
 # --- Sidebar ---
@@ -351,46 +303,21 @@ with sidebar:
 
 
 # --- LLM Instantiation ---
-if provider_api_key:
-    if st.session_state.provider == "OpenAI":
-        st.session_state.llm = ChatOpenAI(
-            model_name=model,
-            openai_api_key=provider_api_key,
-            temperature=temperature,
-            streaming=True,
-            max_tokens=max_tokens,
-        )
-
-    elif st.session_state.provider == "Anthropic":
-        st.session_state.llm = ChatAnthropic(
-            model=model,
-            anthropic_api_key=provider_api_key,
-            temperature=temperature,
-            streaming=True,
-            max_tokens_to_sample=max_tokens,
-        )
-
-    elif st.session_state.provider == "Anyscale Endpoints":
-        st.session_state.llm = ChatAnyscale(
-            model_name=model,
-            anyscale_api_key=provider_api_key,
-            temperature=temperature,
-            streaming=True,
-            max_tokens=max_tokens,
-        )
-
-    elif AZURE_AVAILABLE and st.session_state.provider == "Azure OpenAI":
-        st.session_state.llm = AzureChatOpenAI(
-            openai_api_base=AZURE_OPENAI_BASE_URL,
-            openai_api_version=AZURE_OPENAI_API_VERSION,
-            deployment_name=AZURE_OPENAI_DEPLOYMENT_NAME,
-            openai_api_key=AZURE_OPENAI_API_KEY,
-            openai_api_type="azure",
-            model_version=AZURE_OPENAI_MODEL_VERSION,
-            temperature=temperature,
-            streaming=True,
-            max_tokens=max_tokens,
-        )
+st.session_state.llm = get_llm(
+    provider=st.session_state.provider,
+    model=model,
+    provider_api_key=provider_api_key,
+    temperature=temperature,
+    max_tokens=max_tokens,
+    azure_available=AZURE_AVAILABLE,
+    azure_dict={
+        "AZURE_OPENAI_BASE_URL": AZURE_OPENAI_BASE_URL,
+        "AZURE_OPENAI_API_VERSION": AZURE_OPENAI_API_VERSION,
+        "AZURE_OPENAI_DEPLOYMENT_NAME": AZURE_OPENAI_DEPLOYMENT_NAME,
+        "AZURE_OPENAI_API_KEY": AZURE_OPENAI_API_KEY,
+        "AZURE_OPENAI_MODEL_VERSION": AZURE_OPENAI_MODEL_VERSION,
+    },
+)
 
 # --- Chat History ---
 if len(STMEMORY.messages) == 0:
@@ -451,38 +378,15 @@ if st.session_state.llm:
             stream_handler = StreamHandler(message_placeholder)
             callbacks.append(stream_handler)
 
-            def get_rag_runnable():
-                if document_chat_chain_type == "Q&A Generation":
-                    return get_rag_qa_gen_chain(
-                        st.session_state.retriever,
-                        st.session_state.llm,
-                    )
-                elif document_chat_chain_type == "Summarization":
-                    return get_rag_summarization_chain(
-                        prompt,
-                        st.session_state.retriever,
-                        st.session_state.llm,
-                    )
-                else:
-                    return RetrievalQA.from_chain_type(
-                        llm=st.session_state.llm,
-                        chain_type=document_chat_chain_type,
-                        retriever=st.session_state.retriever,
-                        memory=MEMORY,
-                        output_key="output_text",
-                    ) | (lambda output: output["output_text"])
-
-            st.session_state.chain = (
-                get_rag_runnable()
-                if use_document_chat
-                else LLMChain(
-                    prompt=chat_prompt,
-                    llm=st.session_state.llm,
-                    memory=MEMORY,
-                )
-                | (lambda output: output["text"])
+            st.session_state.chain = get_runnable(
+                use_document_chat,
+                document_chat_chain_type,
+                st.session_state.llm,
+                st.session_state.retriever,
+                MEMORY,
             )
 
+            # --- LLM call ---
             try:
                 full_response = st.session_state.chain.invoke(prompt, config)
 
@@ -492,6 +396,7 @@ if st.session_state.llm:
                     icon="❌",
                 )
 
+        # --- Display output ---
         if full_response is not None:
             message_placeholder.markdown(full_response)
 
@@ -507,6 +412,8 @@ if st.session_state.llm:
             ).url
         except langsmith.utils.LangSmithError:
            st.session_state.trace_link = None
+
+# --- LangSmith Trace Link ---
 if st.session_state.trace_link:
     with sidebar:
         st.markdown(
@@ -550,10 +457,6 @@ if st.session_state.llm:
                     score=score,
                     comment=feedback.get("text"),
                 )
-                # feedback = {
-                #     "feedback_id": str(feedback_record.id),
-                #     "score": score,
-                # }
                 st.toast("Feedback recorded!", icon="📝")
             else:
                 st.warning("Invalid feedback score.")
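
Both the inline chain construction removed here and the new get_runnable helper pipe a chain into a plain callable, as in ") | (lambda output: output[\"text\"])". A minimal standalone sketch of that composition rule, assuming only a LangChain version with runnable (LCEL) support, which the committed code itself requires:

# Hedged sketch, not part of the commit: piping a runnable into a plain
# callable wraps the callable in a RunnableLambda, so invoke() returns the
# extracted field instead of the chain's full output dict.
from langchain.schema.runnable import RunnableLambda

pipeline = RunnableLambda(lambda d: {"text": d.upper()}) | (lambda out: out["text"])
print(pipeline.invoke("hello"))  # prints "HELLO"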
 
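The @st.cache_data wrapper kept in app.py forwards to the relocated get_texts_and_retriever, so the Streamlit cache decorator stays at the app layer while the retriever construction moves to llm_resources.py. A minimal sketch of a call site, assuming the bytes come from an st.file_uploader widget (the widget label and variable names here are illustrative, not part of this commit):

# Hedged sketch, not part of the commit.
uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
if uploaded_file:
    # st.cache_data hashes the bytes argument, so re-uploading the same PDF
    # reuses the previously built texts and ensemble retriever.
    texts, retriever = get_texts_and_retriever_cacheable_wrapper(
        uploaded_file_bytes=uploaded_file.getvalue(),
    )
    st.session_state.retriever = retriever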
langchain-streamlit-demo/llm_resources.py ADDED
@@ -0,0 +1,153 @@
+from tempfile import NamedTemporaryFile
+from typing import Tuple, List
+
+from langchain import LLMChain, FAISS
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.chains import RetrievalQA
+from langchain.chat_models import (
+    AzureChatOpenAI,
+    ChatOpenAI,
+    ChatAnthropic,
+    ChatAnyscale,
+)
+from langchain.document_loaders import PyPDFLoader
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.retrievers import BM25Retriever, EnsembleRetriever
+from langchain.schema import Document, BaseRetriever
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from app import chat_prompt, prompt, openai_api_key
+from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
+from qagen import get_rag_qa_gen_chain
+from summarize import get_rag_summarization_chain
+
+
+def get_runnable(
+    use_document_chat: bool,
+    document_chat_chain_type: str,
+    llm,
+    retriever,
+    memory,
+):
+    if not use_document_chat:
+        return LLMChain(
+            prompt=chat_prompt,
+            llm=llm,
+            memory=memory,
+        ) | (lambda output: output["text"])
+
+    if document_chat_chain_type == "Q&A Generation":
+        return get_rag_qa_gen_chain(
+            retriever,
+            llm,
+        )
+    elif document_chat_chain_type == "Summarization":
+        return get_rag_summarization_chain(
+            prompt,
+            retriever,
+            llm,
+        )
+    else:
+        return RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type=document_chat_chain_type,
+            retriever=retriever,
+            memory=memory,
+            output_key="output_text",
+        ) | (lambda output: output["output_text"])
+
+
+def get_llm(
+    provider: str,
+    model: str,
+    provider_api_key: str,
+    temperature: float,
+    max_tokens: int,
+    azure_available: bool,
+    azure_dict: dict[str, str],
+):
+    if azure_available and provider == "Azure OpenAI":
+        return AzureChatOpenAI(
+            openai_api_base=azure_dict["AZURE_OPENAI_BASE_URL"],
+            openai_api_version=azure_dict["AZURE_OPENAI_API_VERSION"],
+            deployment_name=azure_dict["AZURE_OPENAI_DEPLOYMENT_NAME"],
+            openai_api_key=azure_dict["AZURE_OPENAI_API_KEY"],
+            openai_api_type="azure",
+            model_version=azure_dict["AZURE_OPENAI_MODEL_VERSION"],
+            temperature=temperature,
+            streaming=True,
+            max_tokens=max_tokens,
+        )
+
+    elif provider_api_key:
+        if provider == "OpenAI":
+            return ChatOpenAI(
+                model_name=model,
+                openai_api_key=provider_api_key,
+                temperature=temperature,
+                streaming=True,
+                max_tokens=max_tokens,
+            )
+
+        elif provider == "Anthropic":
+            return ChatAnthropic(
+                model=model,
+                anthropic_api_key=provider_api_key,
+                temperature=temperature,
+                streaming=True,
+                max_tokens_to_sample=max_tokens,
+            )
+
+        elif provider == "Anyscale Endpoints":
+            return ChatAnyscale(
+                model_name=model,
+                anyscale_api_key=provider_api_key,
+                temperature=temperature,
+                streaming=True,
+                max_tokens=max_tokens,
+            )
+
+    return None
+
+
+def get_texts_and_retriever(
+    uploaded_file_bytes: bytes,
+    chunk_size: int = DEFAULT_CHUNK_SIZE,
+    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
+    k: int = DEFAULT_RETRIEVER_K,
+) -> Tuple[List[Document], BaseRetriever]:
+    with NamedTemporaryFile() as temp_file:
+        temp_file.write(uploaded_file_bytes)
+        temp_file.seek(0)
+
+        loader = PyPDFLoader(temp_file.name)
+        documents = loader.load()
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+        texts = text_splitter.split_documents(documents)
+        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
+
+        bm25_retriever = BM25Retriever.from_documents(texts)
+        bm25_retriever.k = k
+
+        faiss_vectorstore = FAISS.from_documents(texts, embeddings)
+        faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
+
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[bm25_retriever, faiss_retriever],
+            weights=[0.5, 0.5],
+        )
+
+        return texts, ensemble_retriever
+
+
+class StreamHandler(BaseCallbackHandler):
+    def __init__(self, container, initial_text=""):
+        self.container = container
+        self.text = initial_text
+
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        self.text += token
+        self.container.markdown(self.text)
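
get_llm returns None whenever no usable credentials reach the selected provider branch, so the app can treat a missing key and an unsupported provider uniformly. A minimal sketch of the call shape with illustrative values (the model name, temperature, and max_tokens are not defaults from this commit; note also that llm_resources.py imports chat_prompt, prompt, and openai_api_key back from app.py, so the module is meant to run in the app's context):

# Hedged sketch, not part of the commit: exercising get_llm's branching.
import os

from llm_resources import get_llm

llm = get_llm(
    provider="OpenAI",
    model="gpt-3.5-turbo",
    provider_api_key=os.environ.get("OPENAI_API_KEY", ""),
    temperature=0.7,
    max_tokens=1000,
    azure_available=False,
    azure_dict={},
)
if llm is None:
    # Falsy key and no Azure config: every branch falls through.
    raise SystemExit("No credentials for the selected provider.")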