import torch
import gradio as gr
from datasets import concatenate_datasets, load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from peft import PeftModel, PeftConfig
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load SAMSum dataset for generating questions (note: it is not referenced again in this script)
train_dataset = load_dataset("samsum", split='train', trust_remote_code=True)
val_dataset = load_dataset("samsum", split='validation', trust_remote_code=True)
samsum_dataset = concatenate_datasets([train_dataset, val_dataset])

# Load the base FLAN-T5 model, attach the RLHF-tuned PEFT adapter, and merge
# the adapter weights into the base model for inference.
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
rlhf_model_path = "raghav-gaggar/PEFT_RLHF_TextSummarizer"
config = PeftConfig.from_pretrained(rlhf_model_path)
ppo_model = PeftModel.from_pretrained(base_model, rlhf_model_path).to(device)
merged_model = ppo_model.merge_and_unload().to(device)

# Inference only: put every model variant in eval mode.
base_model.eval()
ppo_model.eval()
merged_model.eval()
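
# Optional sanity check (kept commented out so the app's startup flow is
# unchanged): the merged model can be queried directly, without LangChain,
# to confirm it loads and generates. The dialogue below is only a placeholder.
# sample_dialogue = "Amanda: I baked cookies. Do you want some?\nJerry: Sure!"
# sample_inputs = tokenizer("Summarize the following dialogue:\n" + sample_dialogue,
#                           return_tensors="pt").to(device)
# with torch.no_grad():
#     sample_ids = merged_model.generate(**sample_inputs, max_new_tokens=60)
# print(tokenizer.decode(sample_ids[0], skip_special_tokens=True))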

# DialogSum provides (dialogue, summary) pairs that serve as retrieval examples.
dialogsum_dataset = load_dataset("knkarthick/dialogsum", trust_remote_code=True)

def format_dialogsum_as_document(example):
    return Document(page_content=f"Dialogue:\n {example['dialogue']}\n\nSummary: {example['summary']}")

# Create documents from DialogSum dataset
documents = []
for split in ['train', 'validation', 'test']:
    documents.extend([format_dialogsum_as_document(example) for example in dialogsum_dataset[split]])

# Split the documents into chunks
text_splitter = CharacterTextSplitter(chunk_size=5200, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Create embeddings and vector store for DialogSum documents
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
    encode_kwargs={"batch_size": 32}
)

vector_store = FAISS.from_documents(docs, embeddings)

# Initialize retriever for DialogSum documents
retriever = vector_store.as_retriever(search_kwargs={"k": 1})
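
# Optional illustration (commented out): the retriever returns the single most
# similar DialogSum example, which RetrievalQA later injects as {context} in the
# prompt. The query string here is just a made-up dialogue snippet.
# retrieved = retriever.get_relevant_documents("Amanda: I baked cookies. Do you want some?")
# print(retrieved[0].page_content[:500])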

prompt_template = """
Concisely summarize the dialogue at the end, like the example provided -

Example -
{context}

Dialogue to be summarized:
{question}

Summary:"""

PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

# Create a Hugging Face pipeline
summarization_pipeline = pipeline(
    "summarization",
    model=merged_model,
    tokenizer=tokenizer,
    max_length=150,
    min_length=20,
    do_sample=False,
)

# Wrap the pipeline in a LangChain LLM
llm = HuggingFacePipeline(pipeline=summarization_pipeline)

# RetrievalQA fetches the most similar DialogSum example as the few-shot
# {context} and passes the user's dialogue through as {question}.
qa_chain = RetrievalQA.from_chain_type(
    llm, retriever=retriever, chain_type_kwargs={"prompt": PROMPT}
)

# Summarization function backing the Gradio interface
def summarize_conversation(question):
    result = qa_chain.invoke({"query": question})
    return result["result"]
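
# Example usage outside Gradio (commented out; the conversation is made up):
# print(summarize_conversation(
#     "Tom: Are we still meeting at 6?\nAnna: Yes, see you at the cafe."
# ))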

example = [
    ["Amanda: I baked cookies. Do you want some? \nJerry: Sure! \nAmanda: I'll bring you tomorrow :-)"]
]

# Create Gradio interface
iface = gr.Interface(
    fn=summarize_conversation,
    inputs=gr.Textbox(lines=10, label="Enter conversation here"),
    outputs=gr.Textbox(label="Summary"),
    title="Conversation Summarizer",
    description="Enter a conversation, and the AI will provide a concise summary.",
    examples=example
)

# Launch the app
iface.launch()