# QuizGenerator / app.py
import re
import numpy as np
import json
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from langchain_google_genai import ChatGoogleGenerativeAI
import os
import gradio as gr
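# Shared resources loaded once at import time: the ModernBERT tokenizer is used
# for token counting, the MiniLM sentence-transformer for segment embeddings,
# and max_tokens caps the size of each semantic chunk.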
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
max_tokens = 4000
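# Remove transcript speaker tags such as "[speaker_1]" and collapse whitespace.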
def clean_text(text):
text = re.sub(r'\[speaker_\d+\]', '', text)
text = re.sub(r'\s+', ' ', text).strip()
return text
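# Split the text into sentence-level segments, then merge segments of fewer
# than three words into their predecessor and split segments of 50+ ModernBERT
# tokens at a punctuation token near their midpoint.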
def split_text_with_modernbert_tokenizer(text):
text = clean_text(text)
rough_splits = re.split(r'(?<=[.!?])\s+', text)
segments = []
current_segment = ""
current_token_count = 0
for sentence in rough_splits:
if not sentence.strip():
continue
sentence_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
if (current_token_count + sentence_tokens > 100 or
re.search(r'[.!?]$', current_segment.strip())):
if current_segment:
segments.append(current_segment.strip())
current_segment = sentence
current_token_count = sentence_tokens
else:
current_segment += " " + sentence if current_segment else sentence
current_token_count += sentence_tokens
if current_segment:
segments.append(current_segment.strip())
refined_segments = []
for segment in segments:
if len(segment.split()) < 3:
if refined_segments:
refined_segments[-1] += ' ' + segment
else:
refined_segments.append(segment)
continue
tokens = tokenizer.tokenize(segment)
if len(tokens) < 50:
refined_segments.append(segment)
continue
break_indices = [i for i, token in enumerate(tokens)
if ('.' in token or ',' in token or '?' in token or '!' in token)
and i < len(tokens) - 1]
if not break_indices or break_indices[-1] < len(tokens) * 0.7:
refined_segments.append(segment)
continue
mid_idx = break_indices[len(break_indices) // 2]
first_half = tokenizer.convert_tokens_to_string(tokens[:mid_idx+1])
second_half = tokenizer.convert_tokens_to_string(tokens[mid_idx+1:])
refined_segments.append(first_half.strip())
refined_segments.append(second_half.strip())
return refined_segments
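# Group the sentence segments into semantically coherent chunks: embed each
# segment, cluster the cosine-distance matrix with agglomerative clustering,
# pack each cluster into chunks of at most max_tokens tokens, then merge
# adjacent chunks whose similarity exceeds 0.75.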
def semantic_chunking(text):
segments = split_text_with_modernbert_tokenizer(text)
segment_embeddings = sentence_model.encode(segments)
distances = cosine_distances(segment_embeddings)
agg_clustering = AgglomerativeClustering(
n_clusters=None,
distance_threshold=1,
metric='precomputed',
linkage='average'
)
clusters = agg_clustering.fit_predict(distances)
# Group segments by cluster
cluster_groups = {}
for i, cluster_id in enumerate(clusters):
if cluster_id not in cluster_groups:
cluster_groups[cluster_id] = []
cluster_groups[cluster_id].append(segments[i])
chunks = []
for cluster_id in sorted(cluster_groups.keys()):
cluster_segments = cluster_groups[cluster_id]
current_chunk = []
current_token_count = 0
for segment in cluster_segments:
segment_tokens = len(tokenizer.encode(segment, truncation=True, add_special_tokens=True))
if segment_tokens > max_tokens:
if current_chunk:
chunks.append(" ".join(current_chunk))
current_chunk = []
current_token_count = 0
chunks.append(segment)
continue
if current_token_count + segment_tokens > max_tokens and current_chunk:
chunks.append(" ".join(current_chunk))
current_chunk = [segment]
current_token_count = segment_tokens
else:
current_chunk.append(segment)
current_token_count += segment_tokens
if current_chunk:
chunks.append(" ".join(current_chunk))
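    # Second pass: merge adjacent chunks whose cosine similarity exceeds 0.75,
    # as long as the merged chunk still fits within max_tokens.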
if len(chunks) > 1:
chunk_embeddings = sentence_model.encode(chunks)
chunk_similarities = 1 - cosine_distances(chunk_embeddings)
i = 0
while i < len(chunks) - 1:
j = i + 1
if chunk_similarities[i, j] > 0.75:
combined = chunks[i] + " " + chunks[j]
combined_tokens = len(tokenizer.encode(combined, truncation=True, add_special_tokens=True))
if combined_tokens <= max_tokens:
# Merge chunks
chunks[i] = combined
chunks.pop(j)
chunk_embeddings = sentence_model.encode(chunks)
chunk_similarities = 1 - cosine_distances(chunk_embeddings)
else:
i += 1
else:
i += 1
return chunks
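# Ask Gemini (gemini-1.5-flash via langchain-google-genai) to extract a topic
# name, key concepts, a summary, and quiz questions. When is_full_text is True
# the model also performs the segmentation itself; otherwise it analyzes a
# single pre-computed chunk.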
def analyze_segment_with_gemini(cluster_text, is_full_text=False):
llm = ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
temperature=0.7,
max_tokens=None,
timeout=None,
max_retries=3
)
if is_full_text:
prompt = f"""
Analyze the following text (likely a transcript or document) and:
1. First, identify distinct segments or topics within the text
2. For each segment/topic you identify:
- Provide a concise topic name (3-5 words)
- List 3-5 key concepts discussed in that segment
- Write a brief summary of that segment (3-5 sentences)
- Create 5 quiz questions based DIRECTLY on the content in that segment
For each quiz question:
- Create one correct answer that comes DIRECTLY from the text
- Create two plausible but incorrect answers
- IMPORTANT: Ensure all answer options have similar length (± 3 words)
- Ensure the correct answer is clearly indicated
- The correct answer should be subtly embedded, ensuring that length or wording style does not make it obvious. The incorrect answers should be semantically close and require careful reading to distinguish from the correct one.
Text:
{cluster_text}
Format your response as JSON with the following structure:
{{
"segments": [
{{
"topic_name": "Name of segment 1",
"key_concepts": ["concept1", "concept2", "concept3"],
"summary": "Brief summary of this segment.",
"quiz_questions": [
{{
"question": "Question text?",
"options": [
{{
"text": "Option A",
"correct": false
}},
{{
"text": "Option B",
"correct": true
}},
{{
"text": "Option C",
"correct": false
}}
]
}},
// More questions...
]
}},
// More segments...
]
}}
"""
else:
prompt = f"""
Analyze the following text segment and provide:
1. A concise topic name (3-5 words)
2. 3-5 key concepts discussed
3. A brief summary (6-7 sentences)
4. Create 5 quiz questions based DIRECTLY on the text content (not from your summary)
For each quiz question:
- Create one correct answer that comes DIRECTLY from the text
- Create two plausible but incorrect answers
- IMPORTANT: Ensure all answer options have similar length (± 3 words)
- Ensure the correct answer is clearly indicated
- The correct answer should be subtly embedded, ensuring that length or wording style does not make it obvious. The incorrect answers should be semantically close and require careful reading to distinguish from the correct one.
Text segment:
{cluster_text}
Format your response as JSON with the following structure:
{{
"topic_name": "Name of the topic",
"key_concepts": ["concept1", "concept2", "concept3"],
"summary": "Brief summary of the text segment.",
"quiz_questions": [
{{
"question": "Question text?",
"options": [
{{
"text": "Option A",
"correct": false
}},
{{
"text": "Option B",
"correct": true
}},
{{
"text": "Option C",
"correct": false
}}
]
}},
// More questions...
]
}}
"""
response = llm.invoke(prompt)
response_text = response.content
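    # The model may wrap the JSON in prose or code fences; extract the outermost
    # JSON object before parsing, and fall back to a placeholder on failure.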
try:
json_match = re.search(r'\{[\s\S]*\}', response_text)
if json_match:
response_json = json.loads(json_match.group(0))
else:
response_json = json.loads(response_text)
return response_json
except json.JSONDecodeError as e:
print(f"Error parsing JSON response: {e}")
print(f"Raw response: {response_text}")
if is_full_text:
return {
"segments": [
{
"topic_name": "JSON Parsing Error",
"key_concepts": ["Error in response format"],
"summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
"quiz_questions": []
}
]
}
else:
return {
"topic_name": "JSON Parsing Error",
"key_concepts": ["Error in response format"],
"summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
"quiz_questions": []
}
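# Top-level pipeline: texts under 12,000 tokens are sent to Gemini in a single
# call; longer texts are first split with semantic_chunking and analyzed
# chunk by chunk.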
def process_document_with_quiz(text):
token_count = len(tokenizer.encode(text))
print(f"Text contains {token_count} tokens")
if token_count < 12000:
print("Text is short enough to analyze directly without text segmentation")
full_analysis = analyze_segment_with_gemini(text, is_full_text=True)
results = []
if "segments" in full_analysis:
for i, segment in enumerate(full_analysis["segments"]):
segment["segment_number"] = i + 1
segment["segment_text"] = "Segment identified by Gemini"
results.append(segment)
print(f"Gemini identified {len(results)} segments in the text")
        else:
            print("Unexpected response format from Gemini")
            # Give the fallback result the keys format_quiz_for_display expects.
            full_analysis.setdefault("segment_number", 1)
            full_analysis.setdefault("segment_text", "Full text analyzed by Gemini")
            results = [full_analysis]
return results
chunks = semantic_chunking(text)
print(f"{len(chunks)} semantic chunks were found\n")
results = []
for i, chunk in enumerate(chunks):
print(f"Analyzing segment {i+1}/{len(chunks)}...")
analysis = analyze_segment_with_gemini(chunk, is_full_text=False)
analysis["segment_number"] = i + 1
analysis["segment_text"] = chunk
results.append(analysis)
print(f"Completed analysis of segment {i+1}: {analysis['topic_name']}")
return results
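# Persist the analysis results as JSON (not wired into the Gradio UI).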
def save_results_to_file(results, output_file="analysis_results.json"):
with open(output_file, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"Results saved to {output_file}")
def format_quiz_for_display(results):
output = []
for segment_result in results:
segment_num = segment_result["segment_number"]
topic = segment_result["topic_name"]
output.append(f"\n\n{'='*40}")
output.append(f"SEGMENT {segment_num}: {topic}")
output.append(f"{'='*40}\n")
output.append("KEY CONCEPTS:")
for concept in segment_result["key_concepts"]:
output.append(f"• {concept}")
output.append("\nSUMMARY:")
output.append(segment_result["summary"])
output.append("\nQUIZ QUESTIONS:")
for i, q in enumerate(segment_result["quiz_questions"]):
output.append(f"\n{i+1}. {q['question']}")
for j, option in enumerate(q['options']):
                letter = chr(65 + j)  # option labels A, B, C
correct_marker = " ✓" if option["correct"] else ""
output.append(f" {letter}. {option['text']}{correct_marker}")
return "\n".join(output)
def analyze_document(document_text: str, api_key: str) -> str:
os.environ["GOOGLE_API_KEY"] = api_key
try:
results = process_document_with_quiz(document_text)
formatted_output = format_quiz_for_display(results)
return formatted_output
except Exception as e:
return f"Error processing document: {str(e)}"
with gr.Blocks(title="Quiz Generator ") as app:
gr.Markdown("Quiz Generator")
with gr.Row():
with gr.Column():
input_text = gr.Textbox(
label="Input Text",
placeholder="Paste your document text here...",
lines=10
)
api_key = gr.Textbox(
label="Gemini API Key",
placeholder="Enter your Gemini API key",
type="password"
)
analyze_btn = gr.Button("Analyze Document")
with gr.Column():
output_results = gr.Textbox(
label="Analysis Results",
lines=20
)
analyze_btn.click(
fn=analyze_document,
inputs=[input_text, api_key],
outputs=output_results
)
if __name__ == "__main__":
app.launch()