import gradio as gr from langchain_core.prompts import ChatPromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnablePassthrough, RunnableLambda from langchain_core.messages import SystemMessage, AIMessage, HumanMessage from langchain_astradb import AstraDBChatMessageHistory, AstraDBStore, AstraDBVectorStore from langchain_openai import OpenAIEmbeddings, ChatOpenAI from elevenlabs import VoiceSettings from elevenlabs.client import ElevenLabs from openai import OpenAI from json import loads as json_loads import itertools import time import os AI = True if not hasattr(itertools, "batched"): def batched(iterable, n): "Batch data into lists of length n. The last batch may be shorter." # batched('ABCDEFG', 3) --> ABC DEF G it = iter(iterable) while True: batch = list(itertools.islice(it, n)) if not batch: return yield batch itertools.batched = batched def ai_setup(): global llm, prompt_chain, oai_client if AI: oai_client = OpenAI() llm = ChatOpenAI(model = "gpt-4o", temperature=0.8) embedding = OpenAIEmbeddings() vstore = AstraDBVectorStore( embedding=embedding, collection_name=os.environ.get("ASTRA_DB_COLLECTION"), token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"), api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"), ) retriever = vstore.as_retriever(search_kwargs={'k': 10}) prompt_template = os.environ.get("PROMPT_TEMPLATE") prompt = ChatPromptTemplate.from_messages([('system', prompt_template)]) prompt_chain = ( {"context": retriever, "question": RunnablePassthrough()} | RunnableLambda(format_context) | prompt # | llm # | StrOutputParser() ) else: retriever = RunnableLambda(just_read) def group_and_sort(documents): grouped = {} for document in documents: title = document.metadata["Title"] docs = grouped.get(title, []) grouped[title] = docs docs.append((document.page_content, document.metadata["range"])) for title, values in grouped.items(): values.sort(key=lambda doc:doc[1][0]) for title in grouped: text = '' prev_last = 0 for fragment, (start, last) in grouped[title]: if start < prev_last: text += fragment[prev_last-start:] elif start == prev_last: text += fragment else: text += ' [...] ' text += fragment prev_last = last grouped[title] = text return grouped def format_context(pipeline_state): """Print the state passed between Runnables in a langchain and pass it on""" context = '' documents = group_and_sort(pipeline_state["context"]) for title, text in documents.items(): context += f"\nTitle: {title}\n" context += text context += '\n\n---\n' pipeline_state["context"] = context return pipeline_state def just_read(pipeline_state): fname = "docs.pickle" import pickle return pickle.load(open(fname, "rb")) def new_state(): return gr.State({ "user" : None, "system" : None, "history" : None, }) def session_id(state: dict, request: gr.Request) -> str: return f'{state["user"]}_{request.session_hash}' class History: store = None def __init__(self, name:str, user:str, session_id:str, id:str = None): self.session_id = session_id self.name = name self.user = user self.astra_history = None if id: self.id = id else: self.id = f"{user}_{session_id}" self.create() @classmethod def get_store(self): if self.store is None: self.store = AstraDBStore( collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_sessions', token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"), api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"), ) return self.store @classmethod def from_dict(cls, id:str, data:dict): name = f":{id}" name = data.get("name", name) answer = cls(name, user=data["user"], id = id, session_id=data["session"]) return answer @classmethod def get_histories(cls, user:str): store = cls.get_store() histories = [] keys = [k for k in store.yield_keys(prefix=f"{user}_")] for id, history in zip(keys, store.mget(keys)): history = cls.from_dict(id = id, data = history) histories.append(history) return histories @classmethod def load(cls, id:str): data = cls.get_store().mget([id]) return cls.from_dict(id, data[0]) def __str__(self): return f"{self.id}:{self.name}" def create(self): history = { 'session' : self.session_id, 'user' : self.user, 'timestamp' : time.asctime(time.gmtime()), 'name' : self.name } self.get_store().mset([(self.id, history)]) @staticmethod def get_history_collection_name(): return f'{os.environ.get("ASTRA_DB_COLLECTION")}_chat_history' def get_astra_history(self): if self.astra_history is None: self.astra_history = AstraDBChatMessageHistory( session_id=self.id, collection_name=self.get_history_collection_name(), token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"), api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"), ) return self.astra_history def add(self, type:str, message): if type == "system": self.get_astra_history().add_message(message) elif type == "user": self.get_astra_history().add_user_message(message) elif type == "ai": self.get_astra_history().add_ai_message(message) def messages(self): return self.get_astra_history().messages def clear(self): self.get_astra_history().clear() def delete(self): self.clear() self.get_store().mdelete([self.id]) def auth(token, state, request: gr.Request): tokens=os.environ.get("APP_TOKENS") if not tokens: state["user"] = "anonymous" else: tokens=json_loads(tokens) state["user"] = tokens.get(token, None) return "", state AUTH_JS = """function auth_js(token, state) { if (!!document.location.hash) { token = document.location.hash document.location.hash="" } return [token, state] } """ def not_authenticated(state): answer = (state is None) or (not state['user']) if answer: gr.Warning("You need to authenticate first") return answer def list_histories(state): if not_authenticated(state): return gr.update() histories = History.get_histories(state["user"]) answer = [(h.name, h.id) for h in histories] return gr.update(choices=answer, value=None) def add_history(state, request, type, message, name:str = None): if not state["history"]: name = name or message[:60] state["history"] = History( name = name, user = state["user"], session_id = request.session_hash ) state["history"].add(type, message) def load_history(state, history_id): state["history"] = History.load(history_id) history = [] for msg in state["history"].messages(): if type(msg) is HumanMessage: history.append([msg.content, '']) elif type(msg) is AIMessage: if not history: history.append(['','']) last = history[-1] if last[1]: history.append(['', msg.content]) else: last[1] = msg.content if history and len(history[-1]) == 1: user_input = history[-1][0] history = history[:-1] else: user_input = '' if history: state["system"] = get_system_prompt(history[0][0]) return state, history, history, user_input # state, Chatbot, ChatInterface.state, ChatInterface.textbox def get_system_prompt(message): system_prompt = prompt_chain.invoke(message) return system_prompt.messages[0] def chat(message, history, state, request:gr.Request): if not_authenticated(state): yield "You need to authenticate first" else: if AI: if not history: state["system"] = get_system_prompt(message) system_prompt = state["system"] add_history(state, request, "user", message) messages = [system_prompt] for human, ai in history: messages.append(HumanMessage(human)) messages.append(AIMessage(ai)) messages.append(HumanMessage(message)) answer = '' for response in llm.stream(messages): answer += response.content yield answer+'…' else: add_history(state, request, "user", message) msg = f"{time.ctime()}: You said: {message}" answer = ' ' for word in msg.split(): answer += f' {word}' yield answer+'…' time.sleep(0.05) yield answer add_history(state, request, "ai", answer) def on_audio(path, state): if not_authenticated(state): return (gr.update(), None) else: if not path: return [gr.update(), None] if AI: text = oai_client.audio.transcriptions.create( model="whisper-1", file=open(path, "rb"), response_format="text" ) else: text = f"{time.ctime()}: You said something" return (text, None) def play_last(history, state): if not_authenticated(state): pass else: if len(history): voice_id = "IINmogebEQykLiDoSkd0" text = history[-1][1] lab11 = ElevenLabs() whatson=lab11.voices.get(voice_id) response = lab11.generate(text=text, voice=whatson, stream=True) yield from response def chat_change(history): if history: if not history[-1][1]: return gr.update(interactive=False) elif history[-1][1][-1] != '…': return gr.update(interactive=True) return gr.update() # play_last_btn TEXT_TALK = "🎤 Talk" TEXT_STOP = "⏹ Stop" def gr_setup(): theme = gr.Theme.from_hub("freddyaboulton/dracula_revamped@0.3.9") theme.set( color_accent_soft="#818eb6", # ChatBot.svelte / .user / .message-row.panel.user-row . neutral_500 -> neutral_200 background_fill_secondary="#6272a4", # ChatBot.svelte / .bot / .message-row.panel.bot-row . neutral_500 -> neutral_400 background_fill_primary="#818eb6", # DropdownOptions.svelte / item button_primary_text_color="*button_secondary_text_color", button_primary_background_fill="*button_secondary_background_fill") with gr.Blocks( title="Sherlock Holmes stories", fill_height=True, theme=theme, css="footer {visibility: hidden}" ) as app: state = new_state() chatbot = gr.Chatbot(show_label=False, render=False, scale=1) gr.HTML('