###########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 22nd, 2024
###########################################################################################
import os
import torch
from transformers import AutoTokenizer, AutoModel  # chromaDB
from datetime import datetime, date  # add_doc
import chromadb  # chromaDB
from chromadb import Documents, EmbeddingFunction, Embeddings  # chromaDB
from chromadb.utils import embedding_functions  # chromaDB
import ocrmypdf  # convertPDF
from pypdf import PdfReader  # convertPDF
import re  # format_prompt, split_with_overlap
import gradio as gr  # multimodal_response
from huggingface_hub import InferenceClient  # multimodal_response
#---------------------------------------------------
# Specify models for text generation and embeddings
#---------------------------------------------------
myModel="mistralai/Mixtral-8x7b-instruct-v0.1"
#mod="mistralai/Mixtral-8x7b-instruct-v0.1"
#tok=AutoTokenizer.from_pretrained(mod) #,token="hf_...")
#cha=[{"role":"system","content":"A"},{"role":"user","content":"B"},{"role":"assistant","content":"C"}]
#cha=[{"role":"user","content":"U1"},{"role":"assistant","content":"A1"},{"role":"user","content":"U2"},{"role":"assistant","content":"A2"}]
#res=tok.apply_chat_template(cha)
#print(tok.decode(res))
jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
#jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
device='cuda:0' if torch.cuda.is_available() else 'cpu'
jina.to(device)
print(device)
#-----------------
# ChromaDB-client
#-----------------
class JinaEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        embeddings = jina.encode(input) #max_length=2048
        return embeddings.tolist()
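# Minimal usage sketch (hypothetical inputs, not part of the app flow): the embedding
# function maps a list of strings to one float vector per string, as ChromaDB expects:
#vectors = JinaEmbeddingFunction()(["Erster Satz.", "Zweiter Satz."])
#print(len(vectors), len(vectors[0])) # 2 vectors of the model's embedding dimension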
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db/"
onPrem = os.path.exists(dbPath) # run locally if the local db-path exists
if(onPrem==False): dbPath="/home/user/app/db/"
print(dbPath)
client = chromadb.PersistentClient(path=dbPath)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())
jina_ef=JinaEmbeddingFunction()
embeddingModel=jina_ef
databases=[(date.today(),"0")] # start a list of databases
#---------------------------------------------------------------------
# Function for formatting a single message according to the prompt template
#---------------------------------------------------------------------
def format_prompt0(message, history):
    prompt = "<s>"
    #for user_prompt, bot_response in history:
    #    prompt += f"[INST] {user_prompt} [/INST]"
    #    prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
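# Example (sketch): with history ignored, format_prompt0("Wie geht's?", []) returns
# "<s>[INST] Wie geht's? [/INST]", i.e. the single-turn Mixtral-instruct format.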
#-------------------------------------------------------------------------
# Function for formatting a multi-turn dialogue according to the prompt template
#-------------------------------------------------------------------------
def format_prompt(message, history=None, system=None, RAGAddon=None, system2=None, zeichenlimit=None, historylimit=4, removeHTML=False,
                  startOfString="<s>", template0=" [INST] {system} [/INST] </s>", template1=" [INST] {message} [/INST]", template2=" {response}</s>"):
    if zeichenlimit is None: zeichenlimit=1000000000 # :-)
    prompt = ""
    if RAGAddon is not None:
        system += RAGAddon
    if system is not None:
        prompt += template0.format(system=system)
    if history is not None:
        for user_message, bot_response in history[-historylimit:]:
            if user_message is None: user_message = ""
            if bot_response is None: bot_response = ""
            bot_response = re.sub(r"\n\n<details>((.|\n)*?)</details>", "", bot_response) # remove RAG components
            if removeHTML==True: bot_response = re.sub(r"<(.*?)>", "\n", bot_response) # remove HTML components in general (may cause bugs with markdown rendering)
            prompt += template1.format(message=user_message[:zeichenlimit])
            prompt += template2.format(response=bot_response[:zeichenlimit])
    if message is not None: prompt += template1.format(message=message[:zeichenlimit])
    if system2 is not None:
        prompt += system2
    return startOfString+prompt
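# Example (sketch, hypothetical history): the call below yields
# "<s> [INST] Antworte knapp. [/INST] </s> [INST] U1 [/INST] A1</s> [INST] U2 [/INST]"
#print(format_prompt("U2", [("U1","A1")], system="Antworte knapp."))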
#--------------------------------------------
# Function for converting pdf-files to text
#--------------------------------------------
def convertPDF(pdf_file, allow_ocr=False):
    reader = PdfReader(pdf_file)
    full_text = ""
    page_list = []
    def extract_text_from_pdf(reader):
        full_text = ""
        page_list = []
        page_count = 0
        for idx, page in enumerate(reader.pages):
            text = page.extract_text()
            if len(text) > 0:
                page_list.append(text)
                full_text += f"---- Page {idx} ----\n" + text + "\n\n"
                page_count += 1
        return full_text.strip(), page_count, page_list
    # Check if there are any images
    image_count = sum(len(page.images) for page in reader.pages)
    # If there are images but little extractable text, perform OCR on the document
    if allow_ocr:
        print(f"{image_count} Images")
        full_text, page_count, page_list = extract_text_from_pdf(reader) # probe the text layer first
        if image_count > 0 and len(full_text) < 1000:
            out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
    # Extract text:
    full_text, page_count, page_list = extract_text_from_pdf(reader)
    l = len(page_list)
    print(f"{l} Pages")
    # Extract metadata
    metadata = {
        "author": reader.metadata.author,
        "creator": reader.metadata.creator,
        "producer": reader.metadata.producer,
        "subject": reader.metadata.subject,
        "title": reader.metadata.title,
        "image_count": image_count,
        "page_count": page_count,
        "char_count": len(full_text),
    }
    return page_list, full_text, metadata
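# Usage sketch (hypothetical path): returns per-page texts, the concatenated text, and metadata
#pages, text, meta = convertPDF("beispiel.pdf", allow_ocr=False)
#print(meta["page_count"], meta["char_count"])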
#------------------------------------------
# Function for splitting text with overlap
#------------------------------------------
def split_with_overlap0(text, chunk_size=3500, overlap=700):
    """Split text into chunks of at most chunk_size characters, with consecutive chunks overlapping by overlap characters."""
    chunks=[]
    step=max(1,chunk_size-overlap)
    for i in range(0,len(text),step):
        end=min(i+chunk_size,len(text))
        chunks.append(text[i:end])
    return chunks
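# Example (verifiable by hand): split_with_overlap0("abcdefghij", chunk_size=4, overlap=2)
# returns ['abcd', 'cdef', 'efgh', 'ghij', 'ij'], the window advancing by chunk_size-overlap=2.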
def split_with_overlap(text, chunk_size=3500, overlap=700, pattern=r'([.!;?][ \n\r]|[\n\r]{2,})', variant=1, verbose=False):
    """Split text into chunks based on regex (pattern) matches. By default the pattern is '([.!;?][ \\n\\r]|[\\n\\r]{2,})'. Chunks are no longer than a certain number of characters (chunk_size), with chunks overlapping (overlap).
    By default (variant=1) chunking is based on complete sentences, but it is also possible to split only within the left overlap region and within the rest of the chunk size (variant=2), or strictly within both overlap regions (variant=3).
    """
    chunks = []
    overlap=min(overlap,chunk_size) # overlap cannot be larger than chunk_size
    step = max(1, chunk_size - overlap) # step depends on chunk_size and overlap
    def find_pattern(text): # helper to find the first match of the pattern
        return re.search(pattern, text)
    i, lastEnd, end = 0, 0, 0
    while i<len(text):
        if(verbose==True): print("i="+str(i))
        end = min(i + chunk_size, len(text))
        pattern_match = find_pattern(text[i:end]) # first occurrence (if any)
        matchesStart = [x.start() for x in re.finditer(pattern, text[i:end])] # start of every match
        matchesEnd = [x.end() for x in re.finditer(pattern, text[i:end])] # end of every match
        step = max(1, chunk_size - overlap) # normally one step amounts to chunk_size - overlap
        if pattern_match: # if at least one sentence delimiter was found
            for s in matchesStart: # walk through every delimiter
                if ((variant<=2 and s>=overlap) or (variant==3 and s>=overlap and s>(chunk_size-overlap))): # if the delimiter does not lie in the left overlap (variants 1-2), or additionally lies in the right overlap (variant 3) - the latter can produce incomplete sentences
                    end=s+i+1 # set end to the start of the pattern/delimiter within the whole text
                    if(verbose==True): print("***move end:"+str(end)+"; step="+str(step))
                    if(s<(chunk_size-overlap)): step=min(step,max(1,s-overlap)) # advance step at most to the end of the delimiter (only required if end is not inside the overlap)
            if ((variant==1 and i>0) or (variant>=2 and pattern_match.start()<overlap and i>0)): # if the first delimiter lies inside the overlap
                i=i+pattern_match.start()+1 # drop the text portion before the first delimiter
        if(verbose==True): print("i="+str(i)+"; end="+str(end)+"; step="+str(step)+"; len="+str(len(text))+"; match="+str(pattern_match)+"; text="+text[i:end]+"; rest="+text[end:])
        if(end>lastEnd): # if the end has moved (and the chunk does not merely cut the beginning off an already known sentence)
            chunks.append(text[i:end])
            lastEnd=end
            if(verbose==True): print("Text at position "+str(i)+": "+text[i:end])
        i += step
    if(len(text[end:])>0): chunks.append(text[end:]) # append any remainder at the end
    return chunks
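# Usage sketch (assumed toy parameters): sentence-aware chunking with a small window;
# variant=1 keeps complete sentences where possible.
#print(split_with_overlap("Satz eins. Satz zwei. Satz drei.", chunk_size=20, overlap=8, verbose=True))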
fiveChars= r"(?<![ \n\(]bspw|[ \n]inkl)"
fourChars= r"(?<![ \n\(]sog|[ \n]Mio|[ \n]Mrd|[ \n]Tsd|[ \n]Tel)"
threeChars= r"(?<!www|bzw|etc|ggf|[ \n\(]al|[ \n\(]St|[ \n\(]dh|[ \n\(]va|[ \n\(]ca|[ \n\(]Dr|[ \n\(]Hr|[ \n\(]Fr|[0-9]ff)"
twoChars= r"(?<![ \n\(][A-Za-zΆ-Ωά-ωäöüß])"
oneChars= r"(?<![0-9.])"
sentenceRegex=r"(?<=[^.]{4})"+fiveChars+fourChars+threeChars+twoChars+oneChars+r"[.?!](?![A-Za-zΆ-Ωά-ωäöüß0-9.!?'\"])"
sectionRegex=r"\n[ ]*\n[\n ]*"
splitRegex="("+sentenceRegex+"|"+sectionRegex+")"
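# Sketch of the intent behind the lookbehinds above (assumed, not exhaustively tested):
# the sentence regex should not split after common German abbreviations such as "bzw."
# or "Dr.", but should split after genuine sentence-ending punctuation:
#print(re.split(splitRegex, "Dr. Meier kam. Er ging bzw. lief."))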
#---------------------------------------------------------------
# Function for adding docs to ChromaDB and/or returning the collection
#---------------------------------------------------------------
def add_doc(path, session):
    global device
    print("def add_doc!")
    print(path)
    anhang=False
    if(str.lower(path).endswith(".pdf") and os.path.exists(path)):
        doc=convertPDF(path)
        if(len(doc[0])>5):
            if(not "cuda" in device):
                doc="\n\n".join(doc[0][0:5])
                gr.Info("PDF uploaded to DB_"+str(session)+", start indexing excerpt (first 5 pages on CPU setups)!")
            else:
                doc="\n\n".join(doc[0])
                gr.Info("PDF uploaded to DB_"+str(session)+", start indexing!")
        else:
            doc="\n\n".join(doc[0])
            gr.Info("PDF uploaded to DB_"+str(session)+", start indexing!")
        anhang=True
    else:
        gr.Info("No PDF attached - answer based on DB_"+str(session)+".")
    client = chromadb.PersistentClient(path=dbPath)
    print(str(client.list_collections()))
    print(str(session))
    dbName="DB_"+str(session)
    if(not "name="+dbName in str(client.list_collections())):
        #client.delete_collection(name=dbName)
        collection = client.create_collection(
            name=dbName,
            embedding_function=embeddingModel,
            metadata={"hnsw:space": "cosine"})
    else:
        collection = client.get_collection(
            name=dbName, embedding_function=embeddingModel)
    if(anhang==True):
        corpus=split_with_overlap(doc,3500,700,pattern=splitRegex)
        print("Length of corpus: "+str(len(corpus)))
        print("Corpus:"+str(corpus))
        then = datetime.now()
        x=collection.get(include=[])["ids"]
        print(len(x))
        if(len(x)==0):
            chunkSize=40000
            for i in range(round(len(corpus)/chunkSize+0.5)): # 0 is the first batch; the last batch may be incomplete
                print("embed batch "+str(i)+" of "+str(round(len(corpus)/chunkSize+0.5)))
                ids=list(range(i*chunkSize,(i*chunkSize+chunkSize)))
                batch=corpus[i*chunkSize:(i*chunkSize+chunkSize)]
                textIDs=[str(id) for id in ids[0:len(batch)]]
                ids=[str(id+len(x)+1) for id in ids[0:len(batch)]] # id refers to the chromadb-unique ID
                collection.add(documents=batch, ids=ids,
                               metadatas=[{"date": str(date.today())} for b in batch]) #"textID":textIDs, "id":ids,
                print("finished batch "+str(i)+" of "+str(round(len(corpus)/chunkSize+0.5)))
            now = datetime.now()
            gr.Info("Indexing complete!")
            print(now-then) # too much GPU memory for sentence-level embeddings; about 0:00:10.375087 for chunks
    return collection
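# Usage sketch (hypothetical arguments): index a PDF into the session-specific collection
#collection = add_doc("/pfad/zu/datei.pdf", "session123")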
#--------------------------------------------------------
# Function for response to user queries and pot. addenda
#--------------------------------------------------------
def multimodal_response(message, history, dropdown, hfToken, request: gr.Request):
    print("def multimodal response!")
    if(hfToken.startswith("hf_")): # use HF-hub with a custom token if one is provided
        inferenceClient = InferenceClient(model=myModel, token=hfToken)
    else:
        inferenceClient = InferenceClient(myModel)
    global databases
    if request:
        session=request.session_hash
    else:
        session="0"
    length=str(len(history))
    print(databases)
    if(not databases[-1][1]==session):
        databases.append((date.today(),session))
    #print(databases)
    query=message["text"]
    if(len(message["files"])>0): # is there at least one file attached?
        collection=add_doc(message["files"][0], session)
    else: # otherwise, you still want to get the collection with the session-based db
        collection=add_doc(message["text"], session)
    client = chromadb.PersistentClient(path=dbPath)
    print(str(client.list_collections()))
    x=collection.get(include=[])["ids"]
    context=collection.query(query_texts=[query], n_results=1)
    context=["<Kontext "+str(i)+"> "+str(c)+"</Kontext "+str(i)+">" for i,c in enumerate(context["documents"][0])]
    gr.Info("Kontext:\n"+str(context))
    generate_kwargs = dict(
        temperature=float(0.9),
        max_new_tokens=5000,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )
    system="Mit Blick auf das folgende Gespräch und den relevanten Kontext, antworte auf die aktuelle Frage des Nutzers. "+\
        "Antworte ausschließlich auf Basis der Informationen im Kontext.\n\nKontext:\n\n"+\
        str("\n\n".join(context))
    #"Given the following conversation, relevant context, and a follow up question, "+\
    #"reply with an answer to the current question the user is asking. "+\
    #"Return only your response to the question given the above information "+\
    #"following the users instructions as needed.\n\nContext:"+\
    print(system)
    #formatted_prompt = format_prompt0(system+"\n"+query, history)
    formatted_prompt = format_prompt(query, history, system=system)
    print(formatted_prompt)
    output = ""
    try:
        stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        for response in stream:
            output += response.token.text
            yield output
    except Exception as e:
        output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an."
        if(len(context)>0):
            output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
        yield output
        print(str(e))
    if(len(context)>0):
        output=output+"\n\n<br><details open><summary><strong>Quellen</strong></summary><br><ul>"+ "".join(["<li>" + c + "</li>" for c in context])+"</ul></details>"
        yield output
#------------------------------
# Launch Gradio-ChatInterface
#------------------------------
i=gr.ChatInterface(multimodal_response,
    title="Frag dein PDF",
    multimodal=True,
    additional_inputs=[
        gr.Dropdown(
            info="Wähle eine Variante",
            choices=["1","2","3"],
            value="1",
            label="Variante"),
        gr.Textbox(
            value="",
            label="HF_token"),
    ])
i.launch() #allowed_paths=["."])