import re

import torch
import gradio as gr
import pandas as pd
from transformers import AutoTokenizer, pipeline
# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models
editorial_model = "PleIAs/Estienne"
bibliography_model = "PleIAs/Bibliography-Formatter"

editorial_classifier = pipeline(
    "token-classification", model=editorial_model, aggregation_strategy="simple", device=device
)
bibliography_classifier = pipeline(
    "token-classification", model=bibliography_model, aggregation_strategy="simple", device=device
)

tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)
# Helper functions
def preprocess_text(text):
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
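# Illustrative example of the helper above (not part of the original file):
#   preprocess_text("<p>Hello\n  world</p>")  ->  "Hello world"
# Note: preprocess_text is defined here but not called anywhere else in this script.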
def split_text(text, max_tokens=500):
    parts = text.split("\n")
    chunks = []
    current_chunk = ""

    for part in parts:
        temp_chunk = current_chunk + "\n" + part if current_chunk else part
        num_tokens = len(tokenizer.tokenize(temp_chunk))
        if num_tokens <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part

    if current_chunk:
        chunks.append(current_chunk)

    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
        long_text = chunks[0]
        chunks = []
        while len(tokenizer.tokenize(long_text)) > max_tokens:
            split_point = len(long_text) // 2
            while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]):
                split_point += 1
            if split_point >= len(long_text):
                split_point = len(long_text) - 1
            chunks.append(long_text[:split_point].strip())
            long_text = long_text[split_point:].strip()
        if long_text:
            chunks.append(long_text)

    return chunks
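# Usage sketch (illustrative, not in the original Space): split_text greedily packs
# newline-separated parts into chunks of at most max_tokens; when the whole text collapses
# into a single over-long chunk (as happens after CombinedProcessor.process replaces "\n"
# with " ¶ "), it falls back to repeatedly halving the remaining text at a whitespace boundary.
#   chunks = split_text(some_long_text, max_tokens=500)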
def remove_punctuation(text):
    return re.sub(r'[^\w\s]', '', text)

def extract_year(text):
    year_match = re.search(r'\b(\d{4})\b', text)
    return year_match.group(1) if year_match else None
def create_bibtex_entry(data):
    if 'journal' in data:
        entry_type = 'article'
    elif 'booktitle' in data:
        entry_type = 'incollection'  # standard BibTeX type for a chapter/contribution in a book
    else:
        entry_type = 'book'

    # Unclassified tokens are stored under 'None' by CombinedProcessor.process;
    # pop them so they are not emitted as a field, and mine them for a year.
    none_content = (data.pop('None', '') + ' ' + data.pop('none', '')).strip()
    year = extract_year(none_content)
    if year and 'year' not in data:
        data['year'] = year

    author_words = data.get('author', '').split()
    first_author = author_words[0] if author_words else 'Unknown'
    bibtex_id = f"{first_author}{year}" if year else first_author
    bibtex_id = remove_punctuation(bibtex_id.lower())

    bibtex = f"@{entry_type}{{{bibtex_id},\n"
    for key, value in data.items():
        if value.strip():
            if key in ['volume', 'year']:
                value = remove_punctuation(value)
            if key == 'pages':
                value = value.replace('p. ', '')
            bibtex += f"    {key.lower()} = {{{value.strip()}}},\n"
    bibtex = bibtex.rstrip(',\n') + "\n}"
    return bibtex
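# Illustrative sketch of the output shape (hypothetical values, not produced by the models):
# for input data like
#   {'author': 'Smith, J.', 'title': 'An Example Title', 'journal': 'Example Journal', 'None': 'Paris, 2001'}
# create_bibtex_entry would return roughly:
#   @article{smith2001,
#       author = {Smith, J.},
#       title = {An Example Title},
#       journal = {Example Journal},
#       year = {2001}
#   }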
class CombinedProcessor:
    def process(self, user_message):
        # Mark paragraph breaks so they survive chunking and classification
        editorial_text = re.sub("\n", " ¶ ", user_message)
        num_tokens = len(tokenizer.tokenize(editorial_text))
        batch_prompts = split_text(editorial_text, max_tokens=500) if num_tokens > 500 else [editorial_text]

        editorial_out = editorial_classifier(batch_prompts)
        editorial_df = pd.concat([pd.DataFrame(classification) for classification in editorial_out])

        # Filter out only bibliography entries
        bibliography_entries = editorial_df[editorial_df['entity_group'] == 'bibliography']['word'].tolist()

        bibtex_entries = []
        for entry in bibliography_entries:
            bib_out = bibliography_classifier(entry)
            bib_df = pd.DataFrame(bib_out)

            bibtex_data = {}
            current_entity = None
            for _, row in bib_df.iterrows():
                entity_group = row['entity_group']
                word = row['word']

                if entity_group != 'None':
                    if entity_group in bibtex_data:
                        bibtex_data[entity_group] += ' ' + word
                    else:
                        bibtex_data[entity_group] = word
                    current_entity = entity_group
                else:
                    if current_entity:
                        bibtex_data[current_entity] += ' ' + word
                    else:
                        bibtex_data['None'] = bibtex_data.get('None', '') + ' ' + word

            bibtex_entry = create_bibtex_entry(bibtex_data)
            bibtex_entries.append(bibtex_entry)

        # The interface below has a single output Textbox, so return one string
        return "\n\n".join(bibtex_entries)
# Create the processor instance
processor = CombinedProcessor()
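# Quick local sanity check (illustrative; the reference string below is made up):
#   sample = ("Some running text.\n"
#             "Smith, J. An Example Title. Example Journal, vol. 3, 2001, p. 12-34.")
#   print(processor.process(sample))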
# Define the Gradio interface
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Combined Editorial and Bibliography Processor</h1>""")
    text_input = gr.Textbox(label="Your text", type="text", lines=10)
    text_button = gr.Button("Process Text")
    bibtex_output = gr.Textbox(label="BibTeX Entries", lines=15)

    text_button.click(processor.process, inputs=text_input, outputs=[bibtex_output])
if __name__ == "__main__":
    demo.queue().launch()