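"""Chat-with-your-website demo: crawl a site with trafilatura, index the pages with
LangChain + Chroma, and answer questions through a Hugging Face Hub LLM in a Gradio chat UI."""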
import gradio as gr
import pandas as pd
import json
import os

from langchain.document_loaders import DataFrameLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA

from trafilatura import fetch_url, extract
from trafilatura.spider import focused_crawler
from trafilatura.settings import use_config

HF_TOKEN = os.environ.get("HF_TOKEN", None)


def loading_website():
    return "Loading..."

def url_changes(url, pages_to_visit, urls_to_scrape, repo_id):
    # Crawl the target site to collect page URLs, then scrape, chunk, embed and index them.
    to_visit, links = focused_crawler(url, max_seen_urls=pages_to_visit, max_known_urls=urls_to_scrape)
    print(f"{len(links)} URLs to be crawled")

    # Disable trafilatura's per-page extraction timeout (its signal-based timer can misbehave outside the main thread).
    config = use_config()
    config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

    results_df = pd.DataFrame()
    for link in links:
        downloaded = fetch_url(link)
        if downloaded:
            result = extract(downloaded, output_format='json', config=config)
            if result:  # extract() returns None when no usable content is found
                results_df = pd.concat([results_df, pd.DataFrame.from_records([json.loads(result)])])
    results_df.to_csv("./data.csv", index=False)
    
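    # Reload the scraped pages and wrap each row's "text" column as LangChain documents.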
    df = pd.read_csv("./data.csv")
    loader = DataFrameLoader(df, page_content_column="text")
    documents = loader.load()
    print(f"{len(documents)} documents loaded") 

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    print(f"documents split into {len(texts)} chunks")
    
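    # Embed the chunks with a Korean sentence-transformer and index them in a local Chroma store.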
    embeddings = HuggingFaceHubEmbeddings(repo_id="jhgan/ko-sroberta-multitask", huggingfacehub_api_token=HF_TOKEN)

    persist_directory = './vector_db'
    db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
    retriever = db.as_retriever()

    # Use the LLM selected in the dropdown (originally hard-coded to 'beomi/KoAlpaca-Polyglot-5.8B').
    llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": 0.6, "max_new_tokens": 250}, huggingfacehub_api_token=HF_TOKEN)
    
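    # Keep the QA chain module-global so the chat callbacks below can reach it.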
    global qa
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    return "Ready"

def add_text(history, text):
    history = history + [(text, None)]
    return history, ""

def bot(history):
    response = infer(history[-1][0])
    history[-1][1] = response['result']
    return history

def infer(question):
    # Run the retrieval-augmented QA chain on the user's question.
    result = qa({"query": question})
    return result

css="""
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
"""

title = """
<div style="text-align: center;max-width: 700px;">
    <h1>Chat with your website</h1>
    <p style="text-align: center;">Enter the target URL, then click the "Load website to LangChain" button</p>
</div>
"""


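# Gradio UI: URL input, LLM dropdown, indexing status, and a chat panel.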
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)

        with gr.Column():
            target_url = gr.Textbox(label="Load URL", placeholder="Enter the target URL here, e.g. https://www.penta.co.kr/")
            #pdf_doc = gr.File(label="Load URL", file_types=['.pdf'], type="file")
            repo_id = gr.Dropdown(label="LLM", choices=["google/flan-ul2", "OpenAssistant/oasst-sft-1-pythia-12b", "beomi/KoAlpaca-Polyglot-12.8B"], value="google/flan-ul2")
            with gr.Row():
                langchain_status = gr.Textbox(label="Status", placeholder="", interactive=False)
                load_pdf = gr.Button("Load website to LangChain")

        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=350)
        question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
        submit_btn = gr.Button("Send message")
    #load_pdf.click(loading_website, None, langchain_status, queue=False)
    repo_id.change(url_changes, inputs=[target_url, gr.Number(value=5, visible=False), gr.Number(value=50, visible=False), repo_id], outputs=[langchain_status], queue=False)
    load_pdf.click(url_changes, inputs=[target_url, gr.Number(value=5, visible=False), gr.Number(value=50, visible=False), repo_id], outputs=[langchain_status], queue=False)
    question.submit(add_text, [chatbot, question], [chatbot, question]).then(
        bot, chatbot, chatbot
    )
    submit_btn.click(add_text, [chatbot, question], [chatbot, question]).then(
        bot, chatbot, chatbot
    )

demo.launch()