waqasali1707 committed on
Commit
ba07b60
·
verified ·
1 Parent(s): aacf5a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -35
app.py CHANGED
@@ -1,45 +1,29 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
- from pydrive.auth import GoogleAuth
5
- from pydrive.drive import GoogleDrive
6
- import os
7
- from io import BytesIO
8
- from zipfile import ZipFile
9
-
10
- # Initialize Google Auth and Drive
11
  gauth = GoogleAuth()
12
- gauth.LocalWebserverAuth() # Authenticates and opens a browser window
13
  drive = GoogleDrive(gauth)
14
 
15
- # Streamlit UI
16
- st.title("Text Summarizer")
17
-
18
- # Enter the file ID of your model.zip on Google Drive
19
- model_file_id = st.text_input("Enter the Google Drive file ID of the model.zip")
20
 
21
- if model_file_id:
22
- try:
23
- # Download the file from Google Drive
24
- downloaded = drive.CreateFile({'id': model_file_id}).GetContentString()
25
-
26
- # Load the model from the downloaded zip file
27
- with ZipFile(BytesIO(downloaded.encode()), 'r') as zip_ref:
28
- zip_ref.extractall("model_directory")
29
-
30
- # Load the model from the extracted directory
31
- model_path = "model_directory"
32
- tokenizer = AutoTokenizer.from_pretrained(model_path)
33
- model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
34
- st.success("Model loaded successfully!")
35
- except Exception as e:
36
- st.error(f"Failed to load model: {e}")
37
 
38
- # Text area for input
 
39
  text = st.text_area("Enter the text to generate its Summary:")
40
 
41
  # Configuration for generation
42
- generation_config = {'max_length': 100, 'do_sample': True, 'temperature': 0.7}
43
 
44
  if text:
45
  try:
@@ -48,14 +32,14 @@ if text:
48
 
49
  # Generate output
50
  with torch.no_grad():
51
- model_output = model.generate(inputs_encoded["input_ids"], **generation_config)[0]
52
 
53
  # Decode output
54
  output = tokenizer.decode(model_output, skip_special_tokens=True)
55
 
56
- # Display results
57
  with st.expander("Output", expanded=True):
58
  st.write(output)
59
 
60
  except Exception as e:
61
- st.error(f"An error occurred during summarization: {e}")
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
4
+ from pydrive2.auth import GoogleAuth
5
+ from pydrive2.drive import GoogleDrive
6
+
7
+ # Authenticate and create the PyDrive client.
 
 
 
8
  gauth = GoogleAuth()
9
+ gauth.LocalWebserverAuth() # Creates a local webserver and automatically handles authentication.
10
  drive = GoogleDrive(gauth)
11
 
12
+ # Update this path to your local path where the model is stored
13
+ model_path = '/content/drive/My Drive/bart-base'
 
 
 
14
 
15
+ try:
16
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
18
+ except Exception as e:
19
+ st.error(f"Failed to load model: {e}")
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Streamlit UI
22
+ st.title("Text Summarizer")
23
  text = st.text_area("Enter the text to generate its Summary:")
24
 
25
  # Configuration for generation
26
+ generation_config = GenerationConfig(max_new_tokens=100, do_sample=True, temperature=0.7)
27
 
28
  if text:
29
  try:
 
32
 
33
  # Generate output
34
  with torch.no_grad():
35
+ model_output = model.generate(inputs_encoded["input_ids"], generation_config=generation_config)[0]
36
 
37
  # Decode output
38
  output = tokenizer.decode(model_output, skip_special_tokens=True)
39
 
40
+ # Display results in a box with a title
41
  with st.expander("Output", expanded=True):
42
  st.write(output)
43
 
44
  except Exception as e:
45
+ st.error(f"An error occurred during summarization: {e}")