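"""Quiz Generator.

Splits an input document or transcript into semantically coherent chunks
(ModernBERT tokenizer for token counts, all-MiniLM-L6-v2 embeddings with
agglomerative clustering), asks Gemini via LangChain to summarize each chunk
and generate quiz questions, and serves the results through a Gradio UI.
"""
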
import os
import re
import json

import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from langchain_google_genai import ChatGoogleGenerativeAI

# Tokenizer for token counting/splitting; sentence model for segment embeddings.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

# Upper bound on tokens per chunk passed to the LLM.
max_tokens = 4000

def clean_text(text):
    """Strip speaker tags such as [speaker_1] and collapse runs of whitespace."""
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
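
# Example: clean_text("[speaker_1] Hello   world") -> "Hello world"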

def split_text_with_modernbert_tokenizer(text):
    """Split cleaned text into sentence-based segments, then merge very short ones and split very long ones."""
    text = clean_text(text)
    rough_splits = re.split(r'(?<=[.!?])\s+', text)
    segments = []
    current_segment = ""
    current_token_count = 0
    for sentence in rough_splits:
        if not sentence.strip():
            continue
        sentence_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
        # Start a new segment once the running segment is large enough or already ends a sentence.
        if (current_token_count + sentence_tokens > 100 or
                re.search(r'[.!?]$', current_segment.strip())):
            if current_segment:
                segments.append(current_segment.strip())
            current_segment = sentence
            current_token_count = sentence_tokens
        else:
            current_segment += " " + sentence if current_segment else sentence
            current_token_count += sentence_tokens
    if current_segment:
        segments.append(current_segment.strip())

    # Second pass: fold very short segments into their predecessor and split
    # very long segments at a punctuation token near the middle.
    refined_segments = []
    for segment in segments:
        if len(segment.split()) < 3:
            if refined_segments:
                refined_segments[-1] += ' ' + segment
            else:
                refined_segments.append(segment)
            continue
        tokens = tokenizer.tokenize(segment)
        if len(tokens) < 50:
            refined_segments.append(segment)
            continue
        break_indices = [i for i, token in enumerate(tokens)
                         if ('.' in token or ',' in token or '?' in token or '!' in token)
                         and i < len(tokens) - 1]
        if not break_indices or break_indices[-1] < len(tokens) * 0.7:
            refined_segments.append(segment)
            continue
        mid_idx = break_indices[len(break_indices) // 2]
        first_half = tokenizer.convert_tokens_to_string(tokens[:mid_idx + 1])
        second_half = tokenizer.convert_tokens_to_string(tokens[mid_idx + 1:])
        refined_segments.append(first_half.strip())
        refined_segments.append(second_half.strip())
    return refined_segments

def semantic_chunking(text):
    """Group semantically related segments into chunks of at most max_tokens tokens."""
    segments = split_text_with_modernbert_tokenizer(text)
    if len(segments) < 2:
        # AgglomerativeClustering requires at least two samples.
        return segments
    segment_embeddings = sentence_model.encode(segments)
    distances = cosine_distances(segment_embeddings)

    # Cluster segments by semantic similarity (average linkage on cosine distances).
    agg_clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=1,
        metric='precomputed',
        linkage='average'
    )
    clusters = agg_clustering.fit_predict(distances)

    # Group segments by cluster
    cluster_groups = {}
    for i, cluster_id in enumerate(clusters):
        if cluster_id not in cluster_groups:
            cluster_groups[cluster_id] = []
        cluster_groups[cluster_id].append(segments[i])

    # Within each cluster, pack segments into chunks that stay under max_tokens.
    chunks = []
    for cluster_id in sorted(cluster_groups.keys()):
        cluster_segments = cluster_groups[cluster_id]
        current_chunk = []
        current_token_count = 0
        for segment in cluster_segments:
            segment_tokens = len(tokenizer.encode(segment, truncation=True, add_special_tokens=True))
            if segment_tokens > max_tokens:
                # An oversized segment becomes its own chunk.
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                    current_chunk = []
                    current_token_count = 0
                chunks.append(segment)
                continue
            if current_token_count + segment_tokens > max_tokens and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [segment]
                current_token_count = segment_tokens
            else:
                current_chunk.append(segment)
                current_token_count += segment_tokens
        if current_chunk:
            chunks.append(" ".join(current_chunk))

    # Greedily merge adjacent chunks whose embeddings are highly similar,
    # as long as the merged chunk still fits within max_tokens.
    if len(chunks) > 1:
        chunk_embeddings = sentence_model.encode(chunks)
        chunk_similarities = 1 - cosine_distances(chunk_embeddings)
        i = 0
        while i < len(chunks) - 1:
            j = i + 1
            if chunk_similarities[i, j] > 0.75:
                combined = chunks[i] + " " + chunks[j]
                combined_tokens = len(tokenizer.encode(combined, truncation=True, add_special_tokens=True))
                if combined_tokens <= max_tokens:
                    # Merge chunks and recompute similarities.
                    chunks[i] = combined
                    chunks.pop(j)
                    chunk_embeddings = sentence_model.encode(chunks)
                    chunk_similarities = 1 - cosine_distances(chunk_embeddings)
                else:
                    i += 1
            else:
                i += 1
    return chunks

def analyze_segment_with_gemini(cluster_text, is_full_text=False):
    """Ask Gemini for topic name(s), key concepts, a summary, and quiz questions, returned as a dict."""
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.7,
        max_tokens=None,
        timeout=None,
        max_retries=3
    )
    if is_full_text:
        prompt = f"""
        Analyze the following text (likely a transcript or document) and:
        1. First, identify distinct segments or topics within the text
        2. For each segment/topic you identify:
           - Provide a concise topic name (3-5 words)
           - List 3-5 key concepts discussed in that segment
           - Write a brief summary of that segment (3-5 sentences)
           - Create 5 quiz questions based DIRECTLY on the content in that segment

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated
        - The correct answer should be subtly embedded, ensuring that length or wording style does not make it obvious. The incorrect answers should be semantically close and require careful reading to distinguish from the correct one.

        Text:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "segments": [
            {{
              "topic_name": "Name of segment 1",
              "key_concepts": ["concept1", "concept2", "concept3"],
              "summary": "Brief summary of this segment.",
              "quiz_questions": [
                {{
                  "question": "Question text?",
                  "options": [
                    {{
                      "text": "Option A",
                      "correct": false
                    }},
                    {{
                      "text": "Option B",
                      "correct": true
                    }},
                    {{
                      "text": "Option C",
                      "correct": false
                    }}
                  ]
                }},
                // More questions...
              ]
            }},
            // More segments...
          ]
        }}
        """
    else:
        prompt = f"""
        Analyze the following text segment and provide:
        1. A concise topic name (3-5 words)
        2. 3-5 key concepts discussed
        3. A brief summary (6-7 sentences)
        4. Create 5 quiz questions based DIRECTLY on the text content (not from your summary)

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated
        - The correct answer should be subtly embedded, ensuring that length or wording style does not make it obvious. The incorrect answers should be semantically close and require careful reading to distinguish from the correct one.

        Text segment:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "topic_name": "Name of the topic",
          "key_concepts": ["concept1", "concept2", "concept3"],
          "summary": "Brief summary of the text segment.",
          "quiz_questions": [
            {{
              "question": "Question text?",
              "options": [
                {{
                  "text": "Option A",
                  "correct": false
                }},
                {{
                  "text": "Option B",
                  "correct": true
                }},
                {{
                  "text": "Option C",
                  "correct": false
                }}
              ]
            }},
            // More questions...
          ]
        }}
        """
    response = llm.invoke(prompt)
    response_text = response.content
    try:
        # The model may wrap the JSON in extra prose; pull out the outermost braces first.
        json_match = re.search(r'\{[\s\S]*\}', response_text)
        if json_match:
            response_json = json.loads(json_match.group(0))
        else:
            response_json = json.loads(response_text)
        return response_json
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        print(f"Raw response: {response_text}")
        # Fall back to a placeholder result so downstream formatting still works.
        if is_full_text:
            return {
                "segments": [
                    {
                        "topic_name": "JSON Parsing Error",
                        "key_concepts": ["Error in response format"],
                        "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                        "quiz_questions": []
                    }
                ]
            }
        else:
            return {
                "topic_name": "JSON Parsing Error",
                "key_concepts": ["Error in response format"],
                "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                "quiz_questions": []
            }
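
# Example (assumes GOOGLE_API_KEY is already set in the environment, which
# langchain_google_genai picks up by default):
#   analysis = analyze_segment_with_gemini("...some transcript text...", is_full_text=False)
#   print(analysis["topic_name"], analysis["key_concepts"])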

def process_document_with_quiz(text):
    """Analyze a document: short texts go to Gemini in one call, long texts are chunked first."""
    token_count = len(tokenizer.encode(text))
    print(f"Text contains {token_count} tokens")

    if token_count < 12000:
        print("Text is short enough to analyze directly without text segmentation")
        full_analysis = analyze_segment_with_gemini(text, is_full_text=True)
        results = []
        if "segments" in full_analysis:
            for i, segment in enumerate(full_analysis["segments"]):
                segment["segment_number"] = i + 1
                segment["segment_text"] = "Segment identified by Gemini"
                results.append(segment)
            print(f"Gemini identified {len(results)} segments in the text")
        else:
            print("Unexpected response format from Gemini")
            results = [full_analysis]
        return results

    # Longer documents: chunk semantically and analyze each chunk separately.
    chunks = semantic_chunking(text)
    print(f"{len(chunks)} semantic chunks were found\n")
    results = []
    for i, chunk in enumerate(chunks):
        print(f"Analyzing segment {i+1}/{len(chunks)}...")
        analysis = analyze_segment_with_gemini(chunk, is_full_text=False)
        analysis["segment_number"] = i + 1
        analysis["segment_text"] = chunk
        results.append(analysis)
        print(f"Completed analysis of segment {i+1}: {analysis['topic_name']}")
    return results

def save_results_to_file(results, output_file="analysis_results.json"):
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Results saved to {output_file}")

def format_quiz_for_display(results):
    """Render the per-segment analyses and quizzes as plain text (correct answers marked with ✓)."""
    output = []
    for segment_result in results:
        segment_num = segment_result["segment_number"]
        topic = segment_result["topic_name"]
        output.append(f"\n\n{'='*40}")
        output.append(f"SEGMENT {segment_num}: {topic}")
        output.append(f"{'='*40}\n")
        output.append("KEY CONCEPTS:")
        for concept in segment_result["key_concepts"]:
            output.append(f"• {concept}")
        output.append("\nSUMMARY:")
        output.append(segment_result["summary"])
        output.append("\nQUIZ QUESTIONS:")
        for i, q in enumerate(segment_result["quiz_questions"]):
            output.append(f"\n{i+1}. {q['question']}")
            for j, option in enumerate(q['options']):
                letter = chr(97 + j).upper()
                correct_marker = " ✓" if option["correct"] else ""
                output.append(f"  {letter}. {option['text']}{correct_marker}")
    return "\n".join(output)

def analyze_document(document_text: str, api_key: str) -> str:
    """Gradio callback: set the API key, run the pipeline, and return the formatted results."""
    os.environ["GOOGLE_API_KEY"] = api_key
    try:
        results = process_document_with_quiz(document_text)
        formatted_output = format_quiz_for_display(results)
        return formatted_output
    except Exception as e:
        return f"Error processing document: {str(e)}"

with gr.Blocks(title="Quiz Generator") as app:
    gr.Markdown("# Quiz Generator")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Paste your document text here...",
                lines=10
            )
            api_key = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password"
            )
            analyze_btn = gr.Button("Analyze Document")
        with gr.Column():
            output_results = gr.Textbox(
                label="Analysis Results",
                lines=20
            )
    analyze_btn.click(
        fn=analyze_document,
        inputs=[input_text, api_key],
        outputs=output_results
    )
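
# launch() serves the app locally (http://127.0.0.1:7860 by default);
# pass share=True to launch() if a temporary public link is needed.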

if __name__ == "__main__":
    app.launch()