Garvitj committed on
Commit
c1fb255
·
verified ·
1 Parent(s): a55a8d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -6,15 +6,18 @@ from sentence_transformers import SentenceTransformer, util
6
  from PIL import Image
7
  from typing import List
8
  import torch
9
- from transformers import BertTokenizer, BertModel
10
  import torch.nn.functional as F
11
- import language_tool_python # Import LanguageTool for grammar checking
12
 
13
  # Load pre-trained models
14
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
15
  bert_model = BertModel.from_pretrained('bert-base-uncased')
16
  sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
17
 
 
 
 
 
18
  # Initialize Groq client
19
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
20
 
@@ -24,9 +27,6 @@ system_prompt = {
24
  "content": "You are a useful assistant. You reply with efficient answers."
25
  }
26
 
27
- # Initialize grammar checker
28
- tool = language_tool_python.LanguageTool('en-US')
29
-
30
  async def chat_groq(message, history):
31
  messages = [system_prompt]
32
  for msg in history:
@@ -103,13 +103,22 @@ def calculate_sentence_similarity(text1, text2):
103
  embedding2 = sentence_model.encode(text2, convert_to_tensor=True)
104
  return util.pytorch_cos_sim(embedding1, embedding2).item()
105
 
106
- def check_grammar(student_answer):
107
- # Check grammar using LanguageTool
108
- matches = tool.check(student_answer)
109
- errors = len(matches)
 
 
 
 
 
 
 
 
 
110
 
111
- # Apply a penalty based on the number of grammar errors
112
- penalty = 1 - min(0.1 * errors, 0.5) # Maximum penalty is 50%
113
  return penalty
114
 
115
  def compare_answers(student_answer, teacher_answer):
@@ -120,7 +129,7 @@ def compare_answers(student_answer, teacher_answer):
120
  semantic_similarity = (0.75 * bert_similarity + 0.25 * sentence_similarity)
121
 
122
  # Apply grammar penalty
123
- grammar_penalty = check_grammar(student_answer)
124
  final_similarity = semantic_similarity * grammar_penalty
125
 
126
  return final_similarity
 
6
  from PIL import Image
7
  from typing import List
8
  import torch
9
+ from transformers import BertTokenizer, BertModel, T5ForConditionalGeneration, T5Tokenizer
10
  import torch.nn.functional as F
 
11
 
12
  # Load pre-trained models
13
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
14
  bert_model = BertModel.from_pretrained('bert-base-uncased')
15
  sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
16
 
17
+ # Load the pre-trained T5 model and tokenizer for grammar error detection
18
+ grammar_model = T5ForConditionalGeneration.from_pretrained('t5-base')
19
+ grammar_tokenizer = T5Tokenizer.from_pretrained('t5-base')
20
+
21
  # Initialize Groq client
22
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
23
 
 
27
  "content": "You are a useful assistant. You reply with efficient answers."
28
  }
29
 
 
 
 
30
  async def chat_groq(message, history):
31
  messages = [system_prompt]
32
  for msg in history:
 
103
  embedding2 = sentence_model.encode(text2, convert_to_tensor=True)
104
  return util.pytorch_cos_sim(embedding1, embedding2).item()
105
 
106
+ # Grammar detection and penalization using T5 model
107
+ def detect_grammar_errors(text):
108
+ input_text = f"grammar: {text}"
109
+ inputs = grammar_tokenizer.encode(input_text, return_tensors='pt', max_length=512, truncation=True)
110
+ outputs = grammar_model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
111
+ grammar_analysis = grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
112
+
113
+ # Count occurrences of the substring 'error' in the decoded model output (placeholder heuristic; replace with real criteria)
114
+ error_count = grammar_analysis.count('error') # Use your own criteria
115
+ return error_count
116
+
117
+ def penalize_for_grammar(student_answer):
118
+ grammar_errors = detect_grammar_errors(student_answer)
119
 
120
+ # Compute a score multiplier: 1.0 minus 0.05 per detected grammar error, floored at 0.5 (i.e. at most a 50% penalty)
121
+ penalty = max(0.5, 1 - 0.05 * grammar_errors)
122
  return penalty
123
 
124
  def compare_answers(student_answer, teacher_answer):
 
129
  semantic_similarity = (0.75 * bert_similarity + 0.25 * sentence_similarity)
130
 
131
  # Apply grammar penalty
132
+ grammar_penalty = penalize_for_grammar(student_answer)
133
  final_similarity = semantic_similarity * grammar_penalty
134
 
135
  return final_similarity