Update app.py
Browse files
app.py
CHANGED
@@ -6,15 +6,18 @@ from sentence_transformers import SentenceTransformer, util
|
|
6 |
from PIL import Image
|
7 |
from typing import List
|
8 |
import torch
|
9 |
-
from transformers import BertTokenizer, BertModel
|
10 |
import torch.nn.functional as F
|
11 |
-
import language_tool_python # Import LanguageTool for grammar checking
|
12 |
|
13 |
# Load pre-trained models
|
14 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
15 |
bert_model = BertModel.from_pretrained('bert-base-uncased')
|
16 |
sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
17 |
|
|
|
|
|
|
|
|
|
18 |
# Initialize Groq client
|
19 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
20 |
|
@@ -24,9 +27,6 @@ system_prompt = {
|
|
24 |
"content": "You are a useful assistant. You reply with efficient answers."
|
25 |
}
|
26 |
|
27 |
-
# Initialize grammar checker
|
28 |
-
tool = language_tool_python.LanguageTool('en-US')
|
29 |
-
|
30 |
async def chat_groq(message, history):
|
31 |
messages = [system_prompt]
|
32 |
for msg in history:
|
@@ -103,13 +103,22 @@ def calculate_sentence_similarity(text1, text2):
|
|
103 |
embedding2 = sentence_model.encode(text2, convert_to_tensor=True)
|
104 |
return util.pytorch_cos_sim(embedding1, embedding2).item()
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
-
# Apply a penalty based on the number of grammar errors
|
112 |
-
penalty = 1 -
|
113 |
return penalty
|
114 |
|
115 |
def compare_answers(student_answer, teacher_answer):
|
@@ -120,7 +129,7 @@ def compare_answers(student_answer, teacher_answer):
|
|
120 |
semantic_similarity = (0.75 * bert_similarity + 0.25 * sentence_similarity)
|
121 |
|
122 |
# Apply grammar penalty
|
123 |
-
grammar_penalty =
|
124 |
final_similarity = semantic_similarity * grammar_penalty
|
125 |
|
126 |
return final_similarity
|
|
|
6 |
from PIL import Image
|
7 |
from typing import List
|
8 |
import torch
|
9 |
+
from transformers import BertTokenizer, BertModel, T5ForConditionalGeneration, T5Tokenizer
|
10 |
import torch.nn.functional as F
|
|
|
11 |
|
12 |
# Load pre-trained models
|
13 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
14 |
bert_model = BertModel.from_pretrained('bert-base-uncased')
|
15 |
sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
16 |
|
17 |
+
# Load the pre-trained T5 model and tokenizer for grammar error detection
|
18 |
+
grammar_model = T5ForConditionalGeneration.from_pretrained('t5-base')
|
19 |
+
grammar_tokenizer = T5Tokenizer.from_pretrained('t5-base')
|
20 |
+
|
21 |
# Initialize Groq client
|
22 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
23 |
|
|
|
27 |
"content": "You are a useful assistant. You reply with efficient answers."
|
28 |
}
|
29 |
|
|
|
|
|
|
|
30 |
async def chat_groq(message, history):
|
31 |
messages = [system_prompt]
|
32 |
for msg in history:
|
|
|
103 |
embedding2 = sentence_model.encode(text2, convert_to_tensor=True)
|
104 |
return util.pytorch_cos_sim(embedding1, embedding2).item()
|
105 |
|
106 |
+
# Grammar detection and penalization using T5 model
|
107 |
+
def detect_grammar_errors(text):
|
108 |
+
input_text = f"grammar: {text}"
|
109 |
+
inputs = grammar_tokenizer.encode(input_text, return_tensors='pt', max_length=512, truncation=True)
|
110 |
+
outputs = grammar_model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
|
111 |
+
grammar_analysis = grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
112 |
+
|
113 |
+
# Count the number of errors based on specific indicators (customize based on analysis)
|
114 |
+
error_count = grammar_analysis.count('error') # Use your own criteria
|
115 |
+
return error_count
|
116 |
+
|
117 |
+
def penalize_for_grammar(student_answer):
|
118 |
+
grammar_errors = detect_grammar_errors(student_answer)
|
119 |
|
120 |
+
# Apply a penalty based on the number of grammar errors (max 50% penalty)
|
121 |
+
penalty = max(0.5, 1 - 0.05 * grammar_errors)
|
122 |
return penalty
|
123 |
|
124 |
def compare_answers(student_answer, teacher_answer):
|
|
|
129 |
semantic_similarity = (0.75 * bert_similarity + 0.25 * sentence_similarity)
|
130 |
|
131 |
# Apply grammar penalty
|
132 |
+
grammar_penalty = penalize_for_grammar(student_answer)
|
133 |
final_similarity = semantic_similarity * grammar_penalty
|
134 |
|
135 |
return final_similarity
|