File size: 9,860 Bytes
e0b0a1c
06b19dd
e0b0a1c
06b19dd
e0b0a1c
06b19dd
 
e0b0a1c
06b19dd
e0b0a1c
 
06b19dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import re
import logging
import gradio as gr
from huggingface_hub import login
from typing import Set, List, Tuple
from huggingface_hub import InferenceClient
from langchain_openai import AzureChatOpenAI

from langchain.chains import LLMSummarizationCheckerChain


huggingface_key = os.getenv('HUGGINGFACE_KEY')
login(huggingface_key) # Huggingface api token

# Configure logging
logging.basicConfig(filename='factchecking.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
class FactChecking:

    def __init__(self):
        
        self.llm = AzureChatOpenAI(
          azure_deployment = "ChatGPT"
        )

        self.client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    def format_prompt(self, question: str) -> str:
        """
        Formats the input question into a specific structure for text generation.

        Args:
            question (str): The user's question to be formatted.

        Returns:
            str: The formatted prompt including instructions and the question.
        """
        # Combine the instruction template with the user's question
        prompt = f"[INST] you are the ai assitant your task is answr for the user question[/INST]"
        prompt1 = f"[INST] {question} [/INST]"
        return prompt+prompt1

    def mixtral_response(self,prompt, temperature=0.9, max_new_tokens=5000, top_p=0.95, repetition_penalty=1.0):
        """
        Generates a response to the given prompt using text generation parameters.

        Args:
            prompt (str): The user's question.
            temperature (float): Controls randomness in response generation.
            max_new_tokens (int): The maximum number of tokens to generate.
            top_p (float): Nucleus sampling parameter controlling diversity.
            repetition_penalty (float): Penalty for repeating tokens.

        Returns:
            str: The generated response to the input prompt.
        """

        # Adjust temperature and top_p values within acceptable ranges
        temperature = float(temperature)
        if temperature < 1e-2:
            temperature = 1e-2
        top_p = float(top_p)

        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,
        )
        # Simulating a call to a client's text generation API
        formatted_prompt =self.format_prompt(prompt)
        stream =self.client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        output = ""

        for response in stream:
            output += response.token.text
       
        return output.replace("</s>","")

    def extract_unique_sentences(self, text: str) -> Set[str]:
        """
        Extracts unique sentences from the given text.

        Args:
            text (str): The input text.

        Returns:
            Set[str]: A set containing unique sentences.
        """
        try:
          # Tokenize the text into sentences using regex
          sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
          logging.info("Sentence extraction completed successfully.")
          # Return a list of sentences
          return sentences
        except Exception as e:
            logging.error(f"Error occurred in extract_unique_sentences: {e}")
            return set()

    def find_different_sentences(self, text1: str, text2: str) -> List[Tuple[str, str]]:
        """
        Finds sentences that are different between two texts.

        Args:
            text1 (str): The first text.
            text2 (str): The second text.

        Returns:
            List[Tuple[str, str]]: A list of tuples containing sentences and their labels.
        """
        try:
          sentences_text1 = self.extract_unique_sentences(text1)
          sentences_text2 = self.extract_unique_sentences(text2)
          # Initialize labels list
          labels = []
          # Iterate over sentences in text1
          for sentence in sentences_text1:
              if sentence in sentences_text2:
                  # If sentence is common to both texts, assign 'factual' label
                  labels.append((sentence, 'factual'))
              else:
                  # If sentence is unique to text1, assign 'hallucinated' label
                  labels.append((sentence, 'hallucinated'))
          logging.info("Sentence comparison completed successfully.")
          return labels
        except Exception as e:
            logging.error(f"Error occurred in find_different_sentences: {e}")
            return []

    def extract_words(self, text: str) -> List[str]:
        """
        Extracts words from the input text.

        Parameters:
            text (str): The input text.

        Returns:
            List[str]: A list containing the extracted words.
        """
        try:
          # Tokenize the text into words and non-word characters (including spaces) using regex
          chunks = re.findall(r'\b\w+\b|\W+', text)
          logging.info("Words extracted successfully.")
        except Exception as e:
            logging.error(f"An error occurred while extracting words: {str(e)}")
            return []
        else:
            return chunks

    def label_words(self, text1: str, text2: str) -> List[Tuple[str, str]]:
        """
        Labels words in text1 as 'factual' if they are present in text2, otherwise 'hallucinated'.

        Parameters:
            text1 (str): The first text.
            text2 (str): The second text.

        Returns:
            List[Tuple[str, str]]: A list of tuples containing words from text1 and their labels.
        """
        try:
          # Extract chunks from both texts
          chunks_text1 = self.extract_words(text1)
          chunks_text2 = self.extract_words(text2)
          # Convert chunks_text2 into a set for faster lookup
          chunks_set_text2 = set(chunks_text2)
          # Initialize labels list
          labels = []
          # Iterate over chunks in text1
          for chunk in chunks_text1:
              # Check if chunk is present in text2
              if chunk in chunks_set_text2:
                  labels.append((chunk, 'factual'))
              else:
                  labels.append((chunk, 'hallucinated'))
          logging.info("Words labeled successfully.")
          return labels
        except Exception as e:
            logging.error(f"An error occurred while labeling words: {str(e)}")
            return []

    def find_hallucinatted_sentence(self, question: str) -> Tuple[str, List[str]]:
        """
        Finds hallucinated sentences in response to a given question.

        Args:
            question (str): The input question.

        Returns:
            Tuple[str, List[str]]: A tuple containing the original llama_result and a list of hallucinated sentences.
        """
        try:

          # Generate initial response using contract generator
          mixtral_response = self.mixtral_response(question)

          # Create checker chain for summarization checking
          checker_chain = LLMSummarizationCheckerChain.from_llm(self.llm, verbose=True, max_checks=2)

          # Run fact checking on the generated result
          fact_checking_result = checker_chain.run(mixtral_response)

          # Find different sentences between original result and fact checking result
          prediction_list = self.find_different_sentences(mixtral_response, fact_checking_result)

          #word prediction list
          word_prediction_list = self.label_words(mixtral_response, fact_checking_result)

          logging.info("Sentences comparison completed successfully.")
          # Return the original result and list of hallucinated sentences
          return mixtral_response,fact_checking_result,prediction_list,word_prediction_list

        except Exception as e:
            logging.error(f"Error occurred in find_hallucinatted_sentence: {e}")
            return "", []

    def interface(self):
      css=""".gradio-container {background: rgb(157,228,255);
        background: radial-gradient(circle, rgba(157,228,255,1) 0%, rgba(18,115,106,1) 100%);}"""

      with gr.Blocks(css=css) as demo:
        gr.HTML("""
            <center><h1 style="color:#fff">Detect Hallucination</h1></center>""")
        with gr.Row():
          question = gr.Textbox(label="Question")
        with gr.Row():
          button = gr.Button(value="Submit")
        with gr.Row():
          with gr.Column(scale=0.50):
            mixtral_response = gr.Textbox(label="llm answer")
          with gr.Column(scale=0.50):
            fact_checking_result = gr.Textbox(label="Corrected Result")
        with gr.Row():
          with gr.Column(scale=0.50):
            highlighted_prediction = gr.HighlightedText(
                                  label="Sentence Hallucination detection",
                                  combine_adjacent=True,
                                  color_map={"hallucinated": "red", "factual": "green"},
                                  show_legend=True)
          with gr.Column(scale=0.50):
            word_highlighted_prediction = gr.HighlightedText(
                                  label="Word Hallucination detection",
                                  combine_adjacent=True,
                                  color_map={"hallucinated": "red", "factual": "green"},
                                  show_legend=True)
        button.click(self.find_hallucinatted_sentence,question,[mixtral_response,fact_checking_result,highlighted_prediction,word_highlighted_prediction])
      demo.launch(debug=True)


hallucination_detection = FactChecking()
hallucination_detection.interface()