Spaces:
Runtime error
Runtime error
import re | |
import os | |
import logging | |
import gradio as gr | |
from typing import Set, List, Tuple | |
from huggingface_hub import InferenceClient | |
from langchain_openai import AzureChatOpenAI | |
from langchain.chains import LLMChain | |
from langchain.prompts import PromptTemplate | |
from langchain.chains import SimpleSequentialChain | |
from langchain.chains import LLMSummarizationCheckerChain | |
# Send every log record at DEBUG level or above to factchecking.log,
# prefixed with a timestamp and the level name.
logging.basicConfig(
    level=logging.DEBUG,
    filename='factchecking.log',
    format='%(asctime)s - %(levelname)s - %(message)s',
)
class FactChecking:
    """Detect hallucinations in an LLM answer by fact-checking it with a second LLM.

    Mixtral (through ``huggingface_hub.InferenceClient``) answers the user's
    question. An Azure OpenAI chat model then extracts a bulleted list of facts
    from that answer and judges each fact. The verdicts are parsed into
    ``(sentence, label)`` pairs for a Gradio ``HighlightedText`` component.
    """

    # Instruction placed in front of every question sent to Mixtral.
    SYSTEM_INSTRUCTION = "You are an AI assistant. Your task is to answer the user's question."

    # First chain step: turn the Mixtral answer into a bulleted list of facts.
    FACT_EXTRACTION_TEMPLATE = """Given some text, extract a list of facts from the text.
Format your output as a bulleted list.
Text:
{question}
Facts:"""

    # Second chain step: ask the checker model to judge each extracted fact.
    FACT_CHECK_TEMPLATE = """You are an expert fact checker. You have been hired by a major news organization to fact check a very important story.
Here is a bullet point list of facts:
{statement}
For each fact, determine whether it is true or false about the subject. If you are unable to determine whether the fact is true or false, output "Undetermined".
If the fact is false, explain why."""

    # Optional "Fact <n>:" prefix the checker may put in front of a fact,
    # matched for any number (the checker's numbering is not guaranteed).
    _FACT_PREFIX = re.compile(r"^\s*Fact\s+\d+\s*:\s*")

    def __init__(self, azure_deployment: str = "ChatGPT",
                 mixtral_model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"):
        """Create the two model clients.

        Args:
            azure_deployment (str): Azure OpenAI deployment used as fact checker.
            mixtral_model (str): Hugging Face model id used to answer questions.
        """
        # NOTE(review): no endpoint/key is passed here, so AzureChatOpenAI
        # presumably picks them up from environment variables. Confirm the
        # Space's secrets are set.
        self.llm = AzureChatOpenAI(
            azure_deployment=azure_deployment
        )
        self.client = InferenceClient(mixtral_model)

    def format_prompt(self, question: str) -> str:
        """
        Formats the input question into a Mixtral-Instruct prompt.

        Args:
            question (str): The user's question to be formatted.

        Returns:
            str: A single ``[INST] ... [/INST]`` turn containing the
            instruction followed by the question.
        """
        # The Mixtral instruct format alternates user/assistant turns, so the
        # instruction and the question share one user turn instead of being
        # sent as two consecutive [INST] blocks.
        return f"[INST] {self.SYSTEM_INSTRUCTION}\n\n{question} [/INST]"

    def mixtral_response(self, prompt, temperature=0.9, max_new_tokens=5000, top_p=0.95, repetition_penalty=1.0):
        """
        Generates a response to the given prompt using text generation parameters.

        Args:
            prompt (str): The user's question.
            temperature (float): Controls randomness in response generation;
                raised to at least 0.01.
            max_new_tokens (int): The maximum number of tokens to generate.
            top_p (float): Nucleus sampling parameter controlling diversity;
                clamped into [0.01, 0.99].
            repetition_penalty (float): Penalty for repeating tokens.

        Returns:
            str: The generated text with any ``</s>`` markers removed.
        """
        # Keep sampling parameters inside the open ranges the text-generation
        # backend accepts (temperature > 0, 0 < top_p < 1).
        temperature = max(float(temperature), 1e-2)
        top_p = min(max(float(top_p), 1e-2), 0.99)
        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,
        )
        formatted_prompt = self.format_prompt(prompt)
        stream = self.client.text_generation(
            formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        # Join the streamed token texts in one pass instead of repeated `+=`.
        output = "".join(response.token.text for response in stream)
        return output.replace("</s>", "")

    def extract_unique_sentences(self, text: str) -> List[str]:
        """
        Extracts unique sentences from the given text.

        Sentences end at "." or "?" followed by whitespace. The regex avoids
        splitting after abbreviations such as "e.g." or "Mr.".

        Args:
            text (str): The input text.

        Returns:
            List[str]: Non-empty sentences, without duplicates, in order of
            first appearance. An empty list if ``text`` is not a string.
        """
        try:
            sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
        except TypeError as e:
            # re.split raises TypeError for non-string input (e.g. None).
            logging.error(f"Error occurred in extract_unique_sentences: {e}")
            return []
        logging.info("Sentence extraction completed successfully.")
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return [sentence for sentence in dict.fromkeys(sentences) if sentence]

    def find_different_sentences(self, answer: str) -> List[Tuple[str, str]]:
        """
        Parses the fact checker's answer into (sentence, label) pairs.

        The answer is assumed (the prompt does not enforce it) to consist of
        blank-line-separated paragraphs. Each paragraph's first line is the
        fact, optionally prefixed with "Fact <n>:", and somewhere the paragraph
        contains the verdict "True." or "False.". Paragraphs with neither
        verdict (e.g. "Undetermined") are skipped.

        Args:
            answer (str): Raw text produced by the fact-checking chain.

        Returns:
            List[Tuple[str, str]]: ``(fact, "factual" | "hallucinated")`` pairs.
        """
        predictions = []
        for paragraph in answer.split("\n\n"):
            paragraph = paragraph.strip()
            true_pos = paragraph.find("True.")
            false_pos = paragraph.find("False.")
            if true_pos == -1 and false_pos == -1:
                continue
            # Use whichever verdict appears first, so an explanation that later
            # mentions the other word does not flip the label.
            if true_pos != -1 and (false_pos == -1 or true_pos < false_pos):
                label = "factual"
            else:
                label = "hallucinated"
            first_line = paragraph.split("\n", 1)[0]
            predictions.append((self._FACT_PREFIX.sub("", first_line), label))
        return predictions

    def extract_words(self, text: str) -> List[str]:
        """
        Extracts words from the input text.

        Parameters:
            text (str): The input text.

        Returns:
            List[str]: Word chunks and the runs of non-word characters
            (including spaces) between them, in order. Empty on non-string input.
        """
        try:
            chunks = re.findall(r'\b\w+\b|\W+', text)
        except TypeError as e:
            # re.findall raises TypeError for non-string input (e.g. None).
            logging.error(f"An error occurred while extracting words: {str(e)}")
            return []
        logging.info("Words extracted successfully.")
        return chunks

    def label_words(self, text1: str, text2: str) -> List[Tuple[str, str]]:
        """
        Labels words in text1 as 'factual' if they are present in text2, otherwise 'hallucinated'.

        Parameters:
            text1 (str): The text whose chunks are labelled.
            text2 (str): The reference text.

        Returns:
            List[Tuple[str, str]]: ``(chunk, label)`` for every chunk of text1
            as returned by ``extract_words``.
        """
        try:
            chunks_text1 = self.extract_words(text1)
            # A set makes each membership test O(1).
            known_chunks = set(self.extract_words(text2))
            labels = [
                (chunk, 'factual' if chunk in known_chunks else 'hallucinated')
                for chunk in chunks_text1
            ]
        except Exception as e:
            logging.error(f"An error occurred while labeling words: {str(e)}")
            return []
        logging.info("Words labeled successfully.")
        return labels

    def find_hallucinatted_sentence(self, question: str) -> Tuple[str, List[Tuple[str, str]]]:
        """
        Answers a question with Mixtral and fact-checks the answer.

        Args:
            question (str): The input question.

        Returns:
            Tuple[str, List[Tuple[str, str]]]: The Mixtral answer and the
            ``(fact, label)`` pairs from ``find_different_sentences``. On
            failure, whatever answer was produced (possibly "") and an empty list.
        """
        llm_answer = ""
        try:
            llm_answer = self.mixtral_response(question)
            question_chain = LLMChain(
                llm=self.llm,
                prompt=PromptTemplate(input_variables=["question"], template=self.FACT_EXTRACTION_TEMPLATE),
            )
            assumptions_chain = LLMChain(
                llm=self.llm,
                prompt=PromptTemplate(input_variables=["statement"], template=self.FACT_CHECK_TEMPLATE),
            )
            # Output of the fact-extraction step feeds the fact-check step.
            overall_chain = SimpleSequentialChain(chains=[question_chain, assumptions_chain], verbose=True)
            answer = overall_chain.run(llm_answer)
            prediction_list = self.find_different_sentences(answer)
            logging.info("Sentences comparison completed successfully.")
            return llm_answer, prediction_list
        except Exception:
            # Boundary of the Gradio callback: record the traceback and return
            # empty predictions instead of crashing the UI.
            logging.exception("Error occurred in find_hallucinatted_sentence")
            return llm_answer, []

    def interface(self):
        """Build the Gradio UI and launch it with ``debug=True``."""
        css = """.gradio-container {background: rgb(157,228,255);
background: radial-gradient(circle, rgba(157,228,255,1) 0%, rgba(18,115,106,1) 100%);}"""
        with gr.Blocks(css=css) as demo:
            gr.HTML("""
<center><h1 style="color:#fff">Detect Hallucination</h1></center>""")
            with gr.Row():
                question = gr.Textbox(label="Question")
            with gr.Row():
                button = gr.Button(value="Submit")
            with gr.Row():
                answer_box = gr.Textbox(label="llm answer")
            with gr.Row():
                highlighted_prediction = gr.HighlightedText(
                    label="Sentence Hallucination detection",
                    combine_adjacent=True,
                    color_map={"hallucinated": "red", "factual": "green"},
                    show_legend=True)
            button.click(self.find_hallucinatted_sentence, question, [answer_box, highlighted_prediction])
        demo.launch(debug=True)
if __name__ == "__main__":
    # Create the model clients and start the Gradio app only when run as a
    # script, so importing this module has no network/server side effects.
    hallucination_detection = FactChecking()
    hallucination_detection.interface()