sacreemure committed on
Commit
ce746b1
·
verified ·
1 Parent(s): 3a4742a

Create summarization_app.py

Browse files
Files changed (1) hide show
  1. summarization_app.py +72 -0
summarization_app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, T5ForConditionalGeneration, AutoModelForSeq2SeqLM
3
+ import streamlit as st
4
+ from summarizer import Summarizer
5
+ import nltk
6
+ nltk.download('punkt')
7
+
8
# Mapping of Hugging Face model identifiers to the human-readable labels shown
# in the UI dropdown. Insertion order is significant: st.selectbox displays the
# labels in this exact order.
available_models = {
    # Russian news summarization (GPT-3-medium fine-tuned on Gazeta)
    'IlyaGusev/rugpt3medium_sum_gazeta': 'Russian Summarization (IlyaGusev/rugpt3medium_sum_gazeta)',
    # German general-purpose summarization (T5-small)
    'Shahm/t5-small-german': 'German Summarization (Shahm/t5-small-german)',
    # English medical-domain summarization
    'Falconsai/medical_summarization': 'English Summarization (Falconsai/medical_summarization)',
    # Russian medical-domain summarization (T5)
    'sacreemure/med_t5_summ_ru': 'Russian Medical Texts Summarization (sacreemure/med_t5_summ_ru)',
}
14
+
15
+
16
def hugging_face_summarize(article, model_name, num_sentences):
    """Summarize *article* with the Hugging Face model *model_name*.

    Dispatches on substrings of the model id to pick the right architecture
    and generation settings, then trims the result to *num_sentences*
    sentences with NLTK's sentence tokenizer.

    Args:
        article: Input text to summarize.
        model_name: Hugging Face model identifier (a key of ``available_models``).
        num_sentences: Maximum number of sentences kept from the generated summary.

    Returns:
        The summary as a single string of at most *num_sentences* sentences.

    Note:
        Model weights are downloaded/loaded on every call; callers that invoke
        this repeatedly (e.g. Streamlit reruns) may want to cache the
        tokenizer/model pairs.
    """
    if "rugpt3medium" in model_name.lower():
        # Decoder-only (causal LM) summarizer for Russian news text.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        input_ids = tokenizer(article, return_tensors='pt', max_length=400,
                              truncation=True, padding=True)["input_ids"]
        # FIX: temperature/top_k are ignored unless do_sample=True, and the
        # original generated 5 sequences while only using the first —
        # num_return_sequences=1 avoids wasted compute.
        output_ids = model.generate(input_ids, max_new_tokens=300,
                                    repetition_penalty=7.0,
                                    num_return_sequences=1, do_sample=True,
                                    temperature=0.7, top_k=50,
                                    early_stopping=True)[0]
        summary = tokenizer.decode(output_ids, skip_special_tokens=True)

    elif "medical" in model_name.lower():
        # English medical seq2seq summarizer.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        input_ids = tokenizer(article, return_tensors='pt', max_length=504,
                              truncation=True, padding=True)["input_ids"]
        output_ids = model.generate(input_ids, max_new_tokens=500)
        # FIX: generate() returns a 2-D batch tensor; decode the first
        # sequence (the original passed the whole batch to decode()).
        summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    elif "med_t5" in model_name.lower():
        # Russian medical T5 summarizer; long inputs and long summaries.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        input_ids = tokenizer(article, return_tensors='pt', max_length=2048,
                              truncation=True)["input_ids"]
        # FIX: as above, do_sample=True makes temperature/top_k effective.
        output_ids = model.generate(input_ids, min_length=800, max_length=1000,
                                    repetition_penalty=2.0,
                                    num_return_sequences=1, do_sample=True,
                                    temperature=0.7, top_k=50,
                                    early_stopping=True)[0]
        summary = tokenizer.decode(output_ids, skip_special_tokens=True)

    else:
        # Fallback seq2seq path (e.g. the German T5 model).
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # FIX: use_fast is a tokenizer-only kwarg and does not belong on the
        # model loader; it was removed here.
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        inputs = tokenizer(article, return_tensors="pt", max_length=800,
                           truncation=True, padding=True)
        output_ids = model.generate(inputs.input_ids, max_new_tokens=100,
                                    num_return_sequences=1)
        summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Keep only the first num_sentences sentences of the generated text.
    summary_sentences = nltk.sent_tokenize(summary)
    summary = ' '.join(summary_sentences[:num_sentences])

    return summary
49
+
50
def main():
    """Build the Streamlit page: model picker, text input, sentence-count
    slider, and a button that triggers summarization of the entered text."""
    st.title("Суммаризиризатор медицинских текстов")
    st.write("Вы можете выбрать модель суммаризации для русского, английского или немецкого")

    # The dropdown shows human-readable labels; the model id is recovered
    # from the label below via reverse lookup.
    selected_model = st.selectbox("Выберите модель:", list(available_models.values()))
    article_text = st.text_area("Введите текст:")
    num_sentences = st.slider(
        "Выберите количество предложений в суммаризированном тексте:",
        min_value=1, max_value=10, value=3,
    )

    if st.button("Суммаризировать"):
        if not article_text:
            st.warning("Пожалуйста, введите текст.")
        else:
            # Reverse lookup: label -> model id (labels are unique by construction).
            model_name = next(
                model_id
                for model_id, label in available_models.items()
                if label == selected_model
            )
            st.subheader("Сокращенный текст:")
            st.write(hugging_face_summarize(article_text, model_name, num_sentences))


if __name__ == "__main__":
    main()