import gradio as gr
import pandas as pd
import json

from langchain.document_loaders import DataFrameLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain import HuggingFacePipeline

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from trafilatura import fetch_url, extract
from trafilatura.spider import focused_crawler
from trafilatura.settings import use_config




def loading_website():
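    # Placeholder status shown while the site is being crawled and indexed.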
    return "Loading..."

def url_changes(url, pages_to_visit, urls_to_scrape, repo_id):
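    # Crawl the target site with trafilatura, scrape each page, index the text in a
    # Chroma vector store, and build a global RetrievalQA chain over it.
    # `repo_id` comes from the UI dropdown, but the local model below is currently hard-coded.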
    to_visit, links = focused_crawler(url, max_seen_urls=pages_to_visit, max_known_urls=urls_to_scrape)
    print(f"{len(links)} to be crawled")

    config = use_config()
    # EXTRACTION_TIMEOUT=0 disables trafilatura's signal-based timeout, which cannot
    # be used outside the main thread (e.g. inside Gradio's request handler).
    config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

    results_df = pd.DataFrame()
    for link in links:
        downloaded = fetch_url(link)
        if downloaded:
            result = extract(downloaded, output_format='json', config=config)
            # extract() returns None when no main text could be pulled from the page
            if result:
                result = json.loads(result)
                results_df = pd.concat([results_df, pd.DataFrame.from_records([result])])
    # index=False keeps pandas from adding an "Unnamed: 0" column on re-read
    results_df.to_csv("./data.csv", index=False)
    
    df = pd.read_csv("./data.csv")
    loader = DataFrameLoader(df, page_content_column="text")
    documents = loader.load()
    print(f"{len(documents)} documents loaded") 

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    print(f"documents splitted into {len(texts)} chunks") 
    
    embeddings = SentenceTransformerEmbeddings(model_name="jhgan/ko-sroberta-multitask")

    persist_directory = './vector_db'
    db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
    retriever = db.as_retriever()

    MODEL = 'beomi/KoAlpaca-Polyglot-5.8B'
    model = AutoModelForCausalLM.from_pretrained(
        MODEL,
        torch_dtype="auto",
    )
    model.eval()
    pipe = pipeline(
        'text-generation',
        model=model,
        tokenizer=MODEL,
        max_new_tokens=512,       # cap generated tokens; max_length would also count the long RAG prompt
        do_sample=False,          # greedy decoding; temperature=0 is not a valid sampling temperature
        repetition_penalty=1.15
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    
    global qa
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    return "Ready"

def add_text(history, text):
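    # Append the user's message to the chat history and clear the input box.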
    history = history + [(text, None)]
    return history, ""

def bot(history):
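    # Answer the most recent user message with the RetrievalQA chain and fill in the reply.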
    response = infer(history[-1][0])
    history[-1][1] = response['result']
    return history

def infer(question):
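    # Query the global RetrievalQA chain; the result includes the answer and the source documents.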

    query = question
    result = qa({"query": query})

    return result

css="""
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
"""

title = """
<div style="text-align: center;max-width: 700px;">
    <h1>Chat with your website</h1>
    <p style="text-align: center;">Enter target URL, click the "Load website to LangChain" button</p>
</div>
"""


with gr.Blocks(css=css) as demo:
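    # Layout: URL input, LLM dropdown, status box with a load button, then the chat widgets.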
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)

        with gr.Column():
            target_url = gr.Textbox(label="Load URL", placeholder="Enter target URL here. EX: https://www.penta.co.kr/")
            #pdf_doc = gr.File(label="Load URL", file_types=['.pdf'], type="file")
            repo_id = gr.Dropdown(label="LLM", choices=["google/flan-ul2", "OpenAssistant/oasst-sft-1-pythia-12b", "beomi/KoAlpaca-Polyglot-12.8B"], value="google/flan-ul2")
            with gr.Row():
                langchain_status = gr.Textbox(label="Status", placeholder="", interactive=False)
                load_pdf = gr.Button("Load website to langchain")

        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=350)
        question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
        submit_btn = gr.Button("Send message")
    #load_pdf.click(loading_website, None, langchain_status, queue=False)
    # Rebuild the index when the model dropdown changes or the load button is clicked;
    # the hidden gr.Number components supply the crawler limits (max_seen_urls=5, max_known_urls=50).
    repo_id.change(url_changes, inputs=[target_url, gr.Number(value=5, visible=False), gr.Number(value=50, visible=False), repo_id], outputs=[langchain_status], queue=False)
    load_pdf.click(url_changes, inputs=[target_url, gr.Number(value=5, visible=False), gr.Number(value=50, visible=False), repo_id], outputs=[langchain_status], queue=False)
    question.submit(add_text, [chatbot, question], [chatbot, question]).then(
        bot, chatbot, chatbot
    )
    submit_btn.click(add_text, [chatbot, question], [chatbot, question]).then(
        bot, chatbot, chatbot
    )

demo.launch()