Spaces:
Sleeping
Sleeping
File size: 8,462 Bytes
876b12f 55cdb25 876b12f 55cdb25 876b12f 55cdb25 876b12f 55cdb25 876b12f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import logging
from typing import Dict, Any, List
from transformers import pipeline
from transformers import AutoTokenizer
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
logger = logging.getLogger(__name__)


class HeadlineAnalyzer:
    """Score how well a headline is supported by its article content.

    Every sentence of the content is paired with the headline and classified
    by the ``roberta-large-mnli`` NLI model into ENTAILMENT / NEUTRAL /
    CONTRADICTION probabilities. Those probabilities are aggregated (mean
    entailment, mean neutral, max contradiction) and blended into a 0-100
    consistency score by ``analyze``.
    """

    # A sentence whose CONTRADICTION probability exceeds this value is
    # reported in "contradictory_phrases".
    contradiction_threshold = 0.1
    # Amount subtracted from a section's entailment per flagged sentence when
    # computing that section's "adjusted_score".
    contradiction_penalty = 0.1
    # Fraction of a section's trailing words repeated at the start of the
    # next section, to keep some context across section boundaries.
    section_overlap = 0.2

    # Becomes True after the NLTK sentence-tokenizer data check has run, so
    # the check (and any download) happens at most once per process.
    _punkt_ready = False

    def __init__(self):
        """Load the ``roberta-large-mnli`` classification pipeline and tokenizer."""
        self.nli_pipeline = pipeline("text-classification", model="roberta-large-mnli")
        self.tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
        # Model input window in tokens, headline and content combined.
        self.max_length = 512

    @staticmethod
    def _empty_result() -> Dict[str, Any]:
        """Return the all-zero result used for empty input and on failure."""
        return {
            "headline_vs_content_score": 0,
            "entailment_score": 0,
            "contradiction_score": 0,
            "contradictory_phrases": [],
        }

    @staticmethod
    def _aggregate(score_dicts: List[Dict[str, float]]) -> Dict[str, float]:
        """Combine per-label scores from several sentences or sections.

        Entailment and neutral are averaged (overall support / tone). The
        contradiction score is the maximum, because one strongly
        contradicting passage is significant on its own. Returns ``{}`` when
        ``score_dicts`` is empty.
        """
        if not score_dicts:
            return {}
        return {
            "ENTAILMENT": float(np.mean([s.get("ENTAILMENT", 0) for s in score_dicts])),
            "NEUTRAL": float(np.mean([s.get("NEUTRAL", 0) for s in score_dicts])),
            "CONTRADICTION": float(np.max([s.get("CONTRADICTION", 0) for s in score_dicts])),
        }

    @classmethod
    def _ensure_punkt(cls) -> None:
        """Make sure NLTK's punkt sentence-tokenizer data is available.

        Older NLTK releases load ``punkt`` and newer ones load ``punkt_tab``,
        so both are checked. Missing data is downloaded once per process
        rather than on every call. If a download fails, ``sent_tokenize``
        raises ``LookupError`` later, and ``analyze`` logs and handles it.
        """
        if cls._punkt_ready:
            return
        for resource in ("punkt", "punkt_tab"):
            try:
                nltk.data.find(f"tokenizers/{resource}")
            except LookupError:
                nltk.download(resource, quiet=True)
        cls._punkt_ready = True

    def _split_sentences(self, text: str) -> List[str]:
        """Split ``text`` into non-blank sentences with NLTK's punkt tokenizer."""
        self._ensure_punkt()
        return [s for s in sent_tokenize(text) if s.strip()]

    def _classify_pair(self, headline: str, sentence: str) -> Dict[str, float]:
        """Return ``{label: probability}`` for the (headline, sentence) pair.

        The pair is passed as ``text``/``text_pair`` so the tokenizer inserts
        the model's own separator tokens. A literal "[SEP]" string is not a
        special token for RoBERTa and would be read as ordinary text.
        ``truncation=True`` keeps an over-long sentence within the model window.
        """
        result = self.nli_pipeline(
            {"text": headline, "text_pair": sentence}, top_k=None, truncation=True
        )
        return {item["label"]: item["score"] for item in result}

    def _split_content(self, headline: str, content: str) -> List[str]:
        """Split content into sections that fit within the token limit.

        Sections break on whitespace. Each section, encoded together with the
        headline, stays within ``self.max_length`` tokens. The exception is a
        single word that alone exceeds the budget; it becomes its own section
        and the classifier truncates it. Each new section starts with the last
        ``section_overlap`` fraction of the previous section's words. Never
        returns an empty section.
        """
        words = content.split()
        headline_tokens = len(self.tokenizer.encode(headline))
        # Special tokens added when a single text is encoded (e.g. <s> ... </s>).
        special_tokens = len(self.tokenizer.encode(""))
        # Headline and section are each counted with their own special tokens.
        # For RoBERTa that totals exactly the pair's <s> A </s></s> B </s>.
        max_content_tokens = self.max_length - headline_tokens

        # Count each word's tokens once, with the leading space it has
        # mid-text, instead of re-encoding the growing section for every word
        # (quadratic in content length). The sum approximates the encoded
        # length of the joined section.
        word_tokens = [
            len(self.tokenizer.encode(" " + word, add_special_tokens=False))
            for word in words
        ]

        sections = []
        start = 0  # index of the first word of the section being built
        used = special_tokens
        for i, count in enumerate(word_tokens):
            # ``i > start`` guarantees the flushed section is never empty.
            if i > start and used + count > max_content_tokens:
                sections.append(" ".join(words[start:i]))
                overlap = int((i - start) * self.section_overlap)
                # overlap < section length, so ``start`` strictly advances.
                start = i - overlap
                used = special_tokens + sum(word_tokens[start:i])
            used += count
        if start < len(words):
            sections.append(" ".join(words[start:]))

        logger.info(
            "Content Splitting:\n"
            f"- Original content length: {len(words)} words\n"
            f"- Split into {len(sections)} sections\n"
            f"- Headline uses {headline_tokens} tokens\n"
            f"- Available tokens per section: {max_content_tokens}\n"
        )
        return sections

    def _analyze_section(self, headline: str, section: str) -> Dict[str, Any]:
        """Analyze a single section of content, sentence by sentence.

        Returns a dict with:
            scores: per-label probabilities aggregated over the section's
                sentences (see ``_aggregate``), or ``{}`` if it has none.
            flagged_phrases: sentences whose CONTRADICTION probability
                exceeds ``contradiction_threshold``.
            adjusted_score: aggregated entailment minus
                ``contradiction_penalty`` per flagged sentence, floored at 0.
        """
        sentence_scores = []
        flagged_phrases = []
        for sentence in self._split_sentences(section):
            scores = self._classify_pair(headline, sentence)
            logger.info(f"Sentence: {sentence}")
            logger.info(f"Scores: {scores}")
            sentence_scores.append(scores)
            if scores.get("CONTRADICTION", 0) > self.contradiction_threshold:
                flagged_phrases.append(sentence)

        section_scores = self._aggregate(sentence_scores)
        adjusted_score = max(
            0.0,
            section_scores.get("ENTAILMENT", 0)
            - len(flagged_phrases) * self.contradiction_penalty,
        )

        logger.info("\nSection Analysis:")
        logger.info("-" * 30)
        logger.info(f"Section preview: {section[:100]}...")
        for label, score in section_scores.items():
            logger.info(f"Label: {label:<12} Score: {score:.3f}")
        return {
            "scores": section_scores,
            "flagged_phrases": flagged_phrases,
            "adjusted_score": adjusted_score,
        }

    def analyze(self, headline: str, content: str) -> Dict[str, Any]:
        """Analyze how well the headline matches the content using an AI model.

        Content longer than ``max_length`` tokens is first split into
        overlapping sections. Sentence scores are aggregated per section and
        then across sections. The final score is
        ``(0.6 * entailment + 0.3 * neutral + 0.1 * (1 - contradiction)) * 100``.

        Returns a dict with ``headline_vs_content_score`` (0-100, 1 decimal),
        ``entailment_score`` and ``contradiction_score`` (2 decimals), and
        ``contradictory_phrases``. The phrases are the flagged sentences from
        every section, in order, without duplicates. Empty input, content
        with no sentences, or any exception yields the all-zero result.
        Exceptions are logged, never raised.
        """
        try:
            logger.info("\n" + "=" * 50)
            logger.info("HEADLINE ANALYSIS STARTED")
            logger.info("=" * 50)

            if not headline.strip() or not content.strip():
                logger.warning("Empty headline or content provided")
                return self._empty_result()

            content_tokens = len(self.tokenizer.encode(content))
            if content_tokens > self.max_length:
                logger.warning(
                    "\nContent Length Warning:\n"
                    f"- Total tokens: {content_tokens}\n"
                    f"- Max allowed: {self.max_length}\n"
                    "- Splitting into sections...\n"
                )
                sections = self._split_content(headline, content)
            else:
                sections = [content]

            section_results = []
            for i, section in enumerate(sections, 1):
                logger.info(f"\nAnalyzing section {i}/{len(sections)}")
                section_results.append(self._analyze_section(headline, section))

            # Sections without any sentence contribute no scores.
            combined = self._aggregate(
                [result["scores"] for result in section_results if result["scores"]]
            )
            if not combined:
                logger.warning("No sentences found in content")
                return self._empty_result()

            entailment_score = combined["ENTAILMENT"]
            neutral_score = combined["NEUTRAL"]
            contradiction_score = combined["CONTRADICTION"]
            # Overlapping sections can flag the same sentence twice; keep the first.
            contradictory_phrases = list(dict.fromkeys(
                phrase
                for result in section_results
                for phrase in result["flagged_phrases"]
            ))

            if len(sections) > 1:
                logger.info("\nAggregated Scores Across Sections:")
                logger.info("-" * 30)
                logger.info(f"Mean Entailment: {entailment_score:.3f}")
                logger.info(f"Max Contradiction: {contradiction_score:.3f}")
                logger.info(f"Mean Neutral: {neutral_score:.3f}")

            final_score = (
                (entailment_score * 0.6)            # Base score from entailment
                + (neutral_score * 0.3)             # Neutral is acceptable
                + ((1 - contradiction_score) * 0.1)  # Small penalty for contradiction
            ) * 100

            logger.info("\nFinal Analysis Results:")
            logger.info("-" * 30)
            logger.info(f"Headline: {headline}")
            logger.info(f"Content Length: {content_tokens} tokens")
            logger.info("\nFinal Scores:")
            logger.info(f"{'Entailment:':<15} {entailment_score:.3f}")
            logger.info(f"{'Neutral:':<15} {neutral_score:.3f}")
            logger.info(f"{'Contradiction:':<15} {contradiction_score:.3f}")
            logger.info(f"\nFinal Score: {final_score:.1f}%")
            logger.info("=" * 50 + "\n")

            return {
                "headline_vs_content_score": round(final_score, 1),
                "entailment_score": round(entailment_score, 2),
                "contradiction_score": round(contradiction_score, 2),
                "contradictory_phrases": contradictory_phrases,
            }
        except Exception as e:
            # Top-level boundary: callers always get a result dict, never an exception.
            logger.error("\nHEADLINE ANALYSIS ERROR")
            logger.error("-" * 30)
            logger.error(f"Error Type: {type(e).__name__}")
            logger.error(f"Error Message: {str(e)}")
            logger.error("Stack Trace:", exc_info=True)
            logger.error("=" * 50 + "\n")
            return self._empty_result()