Upload run.py
Browse files
    	
        run.py
    ADDED
    
    | @@ -0,0 +1,85 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import chromadb
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import gradio as gr
         | 
| 4 | 
            +
            import json
         | 
| 5 | 
            +
            from huggingface_hub import InferenceClient
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # Directory of the persisted Chroma store. The developer-machine
            # location is tried first; if it does not exist, the deployment
            # location is used instead.
            path = '/Users/thiloid/Desktop/LSKI/ole_nest/Chatbot/LLM/chromaTS'
            if not os.path.exists(path):
                path = "/home/user/app/chromaTS"

            print(path)

            # Open the on-disk vector store and print a few diagnostics
            # (heartbeat, version, available collections) at start-up.
            client = chromadb.PersistentClient(path=path)
            for probe in (client.heartbeat, client.get_version, client.list_collections):
                print(probe())

            from chromadb.utils import embedding_functions

            default_ef = embedding_functions.DefaultEmbeddingFunction()
            # Embedding function handed to the collection below.
            sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="T-Systems-onsite/cross-en-de-roberta-sentence-transformer",
            )
            collection = client.get_collection(
                name="chromaTS",
                embedding_function=sentence_transformer_ef,
            )

            # NOTE: `client` is deliberately rebound here. From this point on it
            # is the Hugging Face inference client used for text generation, not
            # the Chroma client; `collection` keeps the handle to the store.
            client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def format_prompt(message, history):
                """Wrap *message* in ``[INST] ... [/INST]`` instruction markers.

                Args:
                    message: Text to place between the instruction markers.
                    history: Accepted for interface compatibility; it is not
                        used, so earlier chat turns never reach the prompt.

                Returns:
                    The string ``"[INST] <message> [/INST]"``.
                """
                return f"[INST] {message} [/INST]"
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            def _build_system_prompt(documents):
                """Return the German system prompt, embedding *documents* if any exist."""
                addon = ""
                # Include the retrieved passages whenever at least one came back
                # (a single hit is still usable context).
                if documents:
                    addon = " Bitte berücksichtige bei deiner Antwort ausschießlich folgende Auszüge aus unserer Datenbank, sofern sie für die Antwort erforderlich sind. Beantworte die Frage knapp und präzise. Ignoriere unpassende Datenbank-Auszüge OHNE sie zu kommentieren, zu erwähnen oder aufzulisten:\n" + "\n".join(documents)
                return "Du bist ein deutschsprachiges KI-basiertes Studienberater Assistenzsystem, das zu jedem Anliegen möglichst geeignete Studieninformationen empfiehlt." + addon + "\n\nUser-Anliegen:"


            def response(
                prompt, history, temperature=0.9, max_new_tokens=500, top_p=0.95, repetition_penalty=1.0,
            ):
                """Stream an answer to *prompt*, grounded in passages from ``collection``.

                The module-level ``collection`` is queried with the user prompt
                for up to 60 results. The returned documents are embedded in a
                German system prompt, and the combined text is sent to the
                module-level ``client`` for streamed text generation.

                Args:
                    prompt: The user's message.
                    history: Chat history supplied by the caller; forwarded
                        unchanged to ``format_prompt``.
                    temperature: Sampling temperature; coerced to ``float`` and
                        raised to at least ``1e-2``.
                    max_new_tokens: Forwarded to the generation call.
                    top_p: Nucleus-sampling value; coerced to ``float``.
                    repetition_penalty: Forwarded to the generation call.

                Yields:
                    The accumulated answer text after each streamed token, and
                    once more after the stream ends.
                """
                temperature = max(float(temperature), 1e-2)
                top_p = float(top_p)
                generate_kwargs = dict(
                    temperature=temperature,
                    max_new_tokens=max_new_tokens,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    do_sample=True,
                    seed=42,
                )

                results = collection.query(
                    query_texts=[prompt],
                    n_results=60,
                )
                # One query text was sent, so only the first result list is relevant.
                documents = results['documents'][0]
                print("TEst")
                print(documents)
                print("_____")

                system = _build_system_prompt(documents)
                formatted_prompt = format_prompt(system + "\n" + prompt, history)
                stream = client.text_generation(
                    formatted_prompt,
                    **generate_kwargs,
                    stream=True,
                    details=True,
                    return_full_text=False,
                )

                output = ""
                # `chunk` rather than `response`, so the function's own name is
                # not shadowed inside its body.
                for chunk in stream:
                    output += chunk.token.text
                    yield output
                # Final yield guarantees one emission even if the stream was empty.
                yield output
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            # Greeting shown as the first bot message (there is no user turn,
            # hence None in the user slot).
            _welcome_turns = [[None, "Herzlich willkommen! Ich bin Chätti ein KI-basiertes Studienassistenzsystem, das für jede Anfrage die am besten Studieninformationen empfiehlt.<br>Erzähle mir, was du gerne tust!"]]
            _chat_window = gr.Chatbot(value=_welcome_turns, render_markdown=True)
            _chat_app = gr.ChatInterface(
                response,
                chatbot=_chat_window,
                title="German Studyhelper Chätti",
            )
            # Alternative launch settings previously noted here:
            # share=False, server_name="0.0.0.0", server_port=7864
            _chat_app.queue().launch(share=True)
            # NOTE(review): launch() appears to block when run as a plain script,
            # so this message may only print after the server stops - confirm.
            print("Interface up and running!")
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             |