Create summarization_app.py
summarization_app.py  +72  -0
summarization_app.py
ADDED
@@ -0,0 +1,72 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, T5ForConditionalGeneration, AutoModelForSeq2SeqLM
import streamlit as st
from summarizer import Summarizer
import nltk
nltk.download('punkt')

# Hugging Face model id -> human-readable label shown in the UI.
available_models = {
    "IlyaGusev/rugpt3medium_sum_gazeta": "Russian Summarization (IlyaGusev/rugpt3medium_sum_gazeta)",
    "Shahm/t5-small-german": "German Summarization (Shahm/t5-small-german)",
    "Falconsai/medical_summarization": "English Summarization (Falconsai/medical_summarization)",
    "sacreemure/med_t5_summ_ru": "Russian Medical Texts Summarization (sacreemure/med_t5_summ_ru)"
}


def hugging_face_summarize(article, model_name, num_sentences):
    # Russian news summarization: a causal LM fine-tuned on the Gazeta dataset.
    if "rugpt3medium" in model_name.lower():
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        input_ids = tokenizer(article, return_tensors='pt', max_length=400, truncation=True, padding=True)["input_ids"]
        # do_sample=True is required so that temperature/top_k take effect and num_return_sequences > 1 is valid.
        output_ids = model.generate(input_ids, max_new_tokens=300, repetition_penalty=7.0, do_sample=True, num_return_sequences=5, temperature=0.7, top_k=50, early_stopping=True)[0]
        summary = tokenizer.decode(output_ids, skip_special_tokens=True)

    # English medical summarization (seq2seq).
    elif "medical" in model_name.lower():
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        input_ids = tokenizer(article, return_tensors='pt', max_length=504, truncation=True, padding=True)["input_ids"]
        # generate() returns a batch; take the first sequence before decoding.
        output_ids = model.generate(input_ids, max_new_tokens=500)[0]
        summary = tokenizer.decode(output_ids, skip_special_tokens=True)

    # Russian medical summarization (T5).
    elif "med_t5" in model_name.lower():
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        input_ids = tokenizer(article, return_tensors='pt', max_length=2048, truncation=True)["input_ids"]
        output_ids = model.generate(input_ids, min_length=800, max_length=1000, repetition_penalty=2.0, num_return_sequences=1, temperature=0.7, top_k=50, early_stopping=True)[0]
        summary = tokenizer.decode(output_ids, skip_special_tokens=True)

    # Default branch: German summarization (Shahm/t5-small-german) or any other seq2seq model.
    else:
        # use_fast is a tokenizer option, not a model option.
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        inputs = tokenizer(article, return_tensors="pt", max_length=800, truncation=True, padding=True)
        output_ids = model.generate(inputs.input_ids, max_new_tokens=100, num_return_sequences=1)
        summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Keep only the first num_sentences sentences of the generated summary.
    summary_sentences = nltk.sent_tokenize(summary)
    summary = ' '.join(summary_sentences[:num_sentences])

    return summary


def main():
    st.title("Суммаризатор медицинских текстов")  # "Medical text summarizer"
    st.write("Вы можете выбрать модель суммаризации для русского, английского или немецкого")  # "You can choose a summarization model for Russian, English or German"

    selected_model = st.selectbox("Выберите модель:", list(available_models.values()))  # "Choose a model:"

    article_text = st.text_area("Введите текст:")  # "Enter the text:"

    num_sentences = st.slider("Выберите количество предложений в суммаризированном тексте:", min_value=1, max_value=10, value=3)  # "Choose the number of sentences in the summary:"

    if st.button("Суммаризировать"):  # "Summarize"
        if article_text:
            # Map the selected label back to its Hugging Face model id.
            model_name = [name for name, label in available_models.items() if label == selected_model][0]
            summary = hugging_face_summarize(article_text, model_name, num_sentences)

            st.subheader("Сокращенный текст:")  # "Summarized text:"
            st.write(summary)
        else:
            st.warning("Пожалуйста, введите текст.")  # "Please enter some text."


if __name__ == "__main__":
    main()
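
For reference, a minimal sketch of exercising the summarization function directly, outside the Streamlit UI. It assumes the file above is saved as summarization_app.py, that its dependencies (streamlit, transformers, torch, nltk, summarizer) are installed, and that the listed model weights can be downloaded from the Hugging Face Hub; the sample input text is hypothetical.

# Minimal sketch: call hugging_face_summarize() directly, bypassing the UI.
from summarization_app import hugging_face_summarize

sample_text = "The patient presented with a persistent cough and mild fever lasting two weeks."  # hypothetical input
summary = hugging_face_summarize(sample_text, "Falconsai/medical_summarization", num_sentences=2)
print(summary)

# The app itself is normally served with:
#   streamlit run summarization_app.py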