ardavey committed on
Commit
3d4f830
·
verified ·
1 Parent(s): bd36a06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -5
app.py CHANGED
@@ -1,15 +1,42 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
# Load the text classification model: a binary AI-vs-human text classifier.
# NOTE(review): no truncation is configured on the pipeline here, so inputs
# longer than the model's maximum sequence length will error — confirm callers
# pre-chunk or truncate.
classifier = pipeline('text-classification', model='ardavey/bert-base-ai-generated-text')
6
 
 
 
 
 
 
 
 
 
 
 
 
 
7
# Define a function for text classification
def classify_text(text):
    """Classify *text* as AI- or human-generated and format the result.

    Runs the module-level ``classifier`` pipeline on the single input and
    returns a human-readable prediction string with a 4-decimal score.
    """
    result = classifier([text])[0]
    if result['label'] == 'LABEL_1':
        label = 'AI Generated Text'
    else:
        label = 'Human Generated Text'
    score = result['score']
    return f"Prediction: {label}, Score: {score:.4f}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Create a Gradio interface
15
  interface = gr.Interface(
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoTokenizer
3
 
4
# Load the text classification model: a binary AI-vs-human text classifier.
classifier = pipeline('text-classification', model='ardavey/bert-base-ai-generated-text')

# Load the tokenizer to handle text preprocessing — used to split long inputs
# into chunks that fit the model's maximum sequence length before classifying.
tokenizer = AutoTokenizer.from_pretrained('ardavey/bert-base-ai-generated-text')
9
+
10
# Define a function to truncate or split the input text into model-sized chunks.
def preprocess_long_text(text, tokenizer, max_length=512):
    """Split *text* into decoded chunks that each fit the model's context.

    Args:
        text: The raw input string (may be arbitrarily long).
        tokenizer: A tokenizer exposing ``encode``/``decode`` (duck-typed).
        max_length: The model's maximum sequence length including the
            special tokens the downstream pipeline will add.

    Returns:
        A non-empty list of text chunks; each chunk re-encodes (with special
        tokens added back) to at most ``max_length`` tokens.
    """
    # Encode WITHOUT special tokens: the classification pipeline re-adds
    # [CLS]/[SEP] when it tokenizes each chunk, so if we chunked at
    # max_length with specials included, a full chunk would re-encode to
    # max_length + 2 tokens and overflow the model. Reserve 2 slots.
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunk_size = max(1, max_length - 2)
    chunks = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]
    # Decode each chunk back to plain text for the pipeline; always return at
    # least one (possibly empty) chunk so callers get a prediction to work with.
    return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks] or ['']
18
+
19
# Define a function for text classification
def classify_text(text):
    """Classify *text* as AI- or human-generated, chunking long inputs.

    Splits the input into model-sized chunks via ``preprocess_long_text``,
    classifies each chunk, then votes by total confidence per label.

    Args:
        text: The raw input string of any length.

    Returns:
        A formatted string with the winning label and its average score.
    """
    # Preprocess the text so every chunk fits the model's context window.
    chunks = preprocess_long_text(text, tokenizer)

    # One prediction per chunk; each is a dict with 'label' and 'score'.
    predictions = [classifier(chunk)[0] for chunk in chunks]

    # Partition confidence scores by predicted label.
    ai_scores = [pred['score'] for pred in predictions if pred['label'] == 'LABEL_1']
    human_scores = [pred['score'] for pred in predictions if pred['label'] == 'LABEL_0']

    # Majority vote by summed confidence across chunks.
    if sum(ai_scores) > sum(human_scores):
        label = "AI Generated Text"
        winning_scores = ai_scores
    else:
        label = "Human Generated Text"
        winning_scores = human_scores

    # Guard the average: the winning list can be empty (e.g. degenerate input
    # where neither label appears), which previously raised ZeroDivisionError.
    score = sum(winning_scores) / len(winning_scores) if winning_scores else 0.0

    return f"Prediction: {label}, Average Score: {score:.4f}"
40
 
41
  # Create a Gradio interface
42
  interface = gr.Interface(