from transformers import pipeline |
from rcsbsearchapi import TextQuery, AttributeQuery, Query |
from rcsbsearchapi.search import Sort, SequenceQuery |
import os |
from dotenv import load_dotenv |
from shiny import App, render, ui, reactive |
import pandas as pd |
import warnings |
import re |
from UniprotKB_P_Sequence_RCSB_API_test import ProteinQuery, ProteinSearchEngine |
import plotly.graph_objects as go |
from shinywidgets import output_widget, render_widget |
# Suppress every Python warning for the rest of the process.
warnings.filterwarnings('ignore')
# Populate os.environ from a .env file (python-dotenv).
load_dotenv()

# Local directory intended to hold downloaded Hugging Face models.
# NOTE(review): `transformers` is imported at the top of the file, before this
# variable is exported; transformers versions that read TRANSFORMERS_CACHE at
# import time will ignore it (the variable is also deprecated in favour of
# HF_HOME). Consider setting it before the import — confirm against the
# installed transformers version.
cache_dir = "./transformers_cache"
os.makedirs(cache_dir, exist_ok=True)
os.environ.update(TRANSFORMERS_CACHE=cache_dir)
print("Hugging Face Cache Path:", os.environ.get("TRANSFORMERS_CACHE"))
class PDBSearchAssistant:
    """Turn a free-text protein-structure request into an RCSB PDB search.

    A text2text-generation pipeline (``google/flan-t5-large`` by default) is
    prompted to extract a resolution cutoff, a sequence, a PDB ID and an
    experimental method from the request. Those hints are validated against
    the request text and combined into an ``rcsbsearchapi`` query whose
    matching entry IDs are returned by ``search_pdb``.
    """

    # Words that mark a request as resolution-related, mapped to the
    # comparison they imply: "less" -> resolution <= cutoff, "greater" ->
    # resolution >= cutoff, None -> resolution-related but no direction.
    # When several directional words occur, the one listed last here wins.
    _RESOLUTION_TERMS = {
        'better': 'less',
        'best': 'less',
        'highest': 'less',
        'good': 'less',
        'fine': 'less',
        'worse': 'greater',
        'worst': 'greater',
        'lowest': 'greater',
        'poor': 'greater',
        'resolution': None,
        'å': None,
        'angstrom': None,
        'angstroms': None,
        'than': None,
        'under': 'less',
        'below': 'less',
        'above': 'greater',
        'over': 'greater',
    }

    # A stand-alone number, optionally followed by an Ångström sign (requests
    # are lower-cased first, so "Å" arrives as "å"). Digits glued to letters,
    # as in "p53" or "4hhb", are deliberately not treated as cutoffs.
    _RESOLUTION_VALUE_RE = re.compile(r'(?<![\w.])\d+(?:\.\d+)?(?:\s*å)?(?!\w)')

    # "Field: value" pairs in the model output. The model may emit all fields
    # on a single line, so each value runs up to the next label or the end.
    _FIELD_RE = re.compile(
        r'\b(resolution|sequence|pdb_id|method)\s*:\s*(.*?)\s*'
        r'(?=\b(?:resolution|sequence|pdb_id|method)\s*:|\Z)',
        re.IGNORECASE | re.DOTALL,
    )

    # Classic four-character PDB identifier: a digit then three alphanumerics.
    _PDB_ID_RE = re.compile(r'[0-9][A-Za-z0-9]{3}')

    # One-letter codes of the 20 standard amino acids.
    _AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')

    # Shortest sequence this class will submit to a sequence search.
    _MIN_SEQUENCE_LENGTH = 25

    # Method names the model may emit, mapped to the exact values stored in
    # the RCSB ``exptl.method`` attribute (a bare "X-RAY" matches nothing).
    _METHOD_VALUES = {
        'X-RAY': ('X-RAY DIFFRACTION',),
        'XRAY': ('X-RAY DIFFRACTION',),
        'X-RAY DIFFRACTION': ('X-RAY DIFFRACTION',),
        'X-RAY CRYSTALLOGRAPHY': ('X-RAY DIFFRACTION',),
        'EM': ('ELECTRON MICROSCOPY',),
        'CRYO-EM': ('ELECTRON MICROSCOPY',),
        'CRYOEM': ('ELECTRON MICROSCOPY',),
        'ELECTRON MICROSCOPY': ('ELECTRON MICROSCOPY',),
        'NMR': ('SOLUTION NMR', 'SOLID-STATE NMR'),
        'SOLUTION NMR': ('SOLUTION NMR',),
        'SOLID-STATE NMR': ('SOLID-STATE NMR',),
    }

    def __init__(self, model_name="google/flan-t5-large", device="cpu"):
        """Load the text2text-generation pipeline used to parse requests.

        Args:
            model_name: Hugging Face model id handed to ``pipeline``.
            device: Device the pipeline runs on (default ``"cpu"``).
        """
        self.pipe = pipeline(
            "text2text-generation",
            model=model_name,
            max_new_tokens=512,
            # NOTE(review): temperature only affects sampling (do_sample=True);
            # with the default decoding it is presumably ignored.
            temperature=0.3,
            torch_dtype="auto",
            device=device,
        )
        # Few-shot extraction prompt; `{query}` is filled in by search_pdb.
        self.prompt_template = """
Extract specific search parameters from the query, if present:
1. Resolution cutoff (in Å)
2. Sequence information
3. Specific PDB ID
4. Experimental method (X-RAY, EM, NMR)
Format:
Resolution: [maximum resolution in Å, if mentioned]
Sequence: [any sequence mentioned]
PDB_ID: [specific PDB ID if mentioned]
Method: [experimental method if mentioned]
Examples:
Query: "Find X-ray structures better than 2.5Å resolution"
Resolution: 2.5
Sequence: none
PDB_ID: none
Method: X-RAY
Query: "Show me NMR structures of kinases"
Resolution: none
Sequence: none
PDB_ID: none
Method: NMR
Now analyze:
Query: {query}
"""

    def search_pdb(self, query):
        """Search the RCSB PDB for entries matching a free-text request.

        Args:
            query: Free-text request, e.g. "Human hemoglobin C resolution
                better than 2.5Å". A whitespace-separated token of at least
                25 amino-acid letters (mostly upper-case) is searched as a
                protein sequence instead of as title text.

        Returns:
            A list of ``{'PDB ID': <entry id>}`` dicts. Empty when no search
            criterion could be derived or the search failed; errors are
            printed rather than raised.
        """
        try:
            formatted_prompt = self.prompt_template.format(query=query)
            response = self.pipe(formatted_prompt)[0]['generated_text']
            print("Generated parameters:", response)
            fields = self._parse_fields(response)

            query_lower = query.lower()
            has_resolution_query, resolution_direction = self._resolution_intent(query_lower)

            queries = []

            # A sequence typed into the request wins over one from the model.
            sequence = (self._sequence_from_query(query)
                        or self._sequence_from_model(fields.get('sequence')))
            if sequence:
                print(f"Adding sequence search with identity 100% for sequence: {sequence}")
                queries.append(SequenceQuery(
                    sequence,
                    identity_cutoff=1.0,
                    evalue_cutoff=1,
                    sequence_type="protein",
                ))
            else:
                # Title search on what remains once resolution wording is removed.
                clean_query = query_lower
                if has_resolution_query:
                    clean_query = self._strip_resolution_terms(clean_query)
                clean_query = ' '.join(clean_query.split())
                print("Cleaned query:", clean_query)
                if clean_query:
                    queries.append(AttributeQuery(
                        attribute="struct.title",
                        operator="contains_phrase",
                        value=clean_query,
                    ))

            # Only filter on resolution when the request itself mentions it,
            # so a cutoff invented by the model is not applied.
            if has_resolution_query:
                resolution_limit = self._resolution_limit(fields.get('resolution'), query_lower)
                if resolution_limit:
                    comparison = "less_or_equal" if resolution_direction == "less" else "greater_or_equal"
                    print(f"Adding resolution filter: {comparison} {resolution_limit}Å")
                    queries.append(AttributeQuery(
                        attribute="rcsb_entry_info.resolution_combined",
                        operator=comparison,
                        value=resolution_limit,
                    ))

            pdb_id = self._pdb_id_from_model(fields.get('pdb_id'), query)
            if pdb_id:
                print(f"Searching for specific PDB ID: {pdb_id}")
                # An explicit ID replaces the text/sequence/resolution criteria.
                queries = [AttributeQuery(
                    attribute="rcsb_id",
                    operator="exact_match",
                    value=pdb_id,
                )]

            method_query = self._method_query(fields.get('method'))
            if method_query is not None:
                queries.append(method_query)

            if not queries:
                return []

            final_query = queries[0]
            for extra in queries[1:]:
                final_query = final_query & extra
            print("Final query:", final_query)

            session = final_query.exec()
            results = []
            try:
                for entry in session:
                    identifier = entry if isinstance(entry, str) else entry.identifier
                    results.append({'PDB ID': identifier})
            except Exception as e:
                # Keep the hits collected before the failure. Re-reading the
                # loop variable here is unsafe: it is unbound if the very
                # first fetch fails, and otherwise duplicates the last hit.
                print(f"Error processing results: {str(e)}")
            print(f"Found {len(results)} structures")
            return results
        except Exception as e:
            print(f"Error during search: {str(e)}")
            print(f"Error type: {type(e)}")
            return []

    @staticmethod
    def _term_pattern(term):
        """Return a regex matching ``term`` only when it is not glued to letters."""
        # Digits may touch the term so that the unit in "2.5å" still matches.
        return r'(?<![^\W\d_])' + re.escape(term) + r'(?![^\W\d_])'

    @classmethod
    def _parse_fields(cls, response):
        """Map lower-cased field names ("resolution", "sequence", "pdb_id",
        "method") to the model's answers.

        "none", "n/a" and empty answers become None. If a field appears more
        than once, the last occurrence wins.
        """
        fields = {}
        for name, value in cls._FIELD_RE.findall(response):
            value = value.strip()
            fields[name.lower()] = None if value.lower() in ('', 'none', 'n/a') else value
        return fields

    @classmethod
    def _resolution_intent(cls, query_lower):
        """Return ``(mentions_resolution, direction)`` for a lower-cased request.

        ``direction`` is "less" unless a "greater" term is found (see
        ``_RESOLUTION_TERMS`` for precedence).
        """
        has_resolution_query = cls._RESOLUTION_VALUE_RE.search(query_lower) is not None
        direction = "less"
        for term, term_direction in cls._RESOLUTION_TERMS.items():
            if re.search(cls._term_pattern(term), query_lower):
                has_resolution_query = True
                if term_direction:
                    direction = term_direction
        return has_resolution_query, direction

    @classmethod
    def _strip_resolution_terms(cls, text):
        """Remove resolution numbers and resolution wording from ``text``."""
        # Replace with spaces (not "") so neighbouring words never fuse.
        text = cls._RESOLUTION_VALUE_RE.sub(' ', text)
        for term in cls._RESOLUTION_TERMS:
            text = re.sub(cls._term_pattern(term), ' ', text)
        return text

    @classmethod
    def _resolution_limit(cls, model_value, query_lower):
        """Return the resolution cutoff in Å, or None when none is found.

        The first number in the model's answer is used; if it has none, the
        first stand-alone number in the request is used instead.
        """
        number = re.search(r'\d+(?:\.\d+)?', model_value or '')
        if number is None:
            typed = cls._RESOLUTION_VALUE_RE.search(query_lower)
            if typed is not None:
                number = re.search(r'\d+(?:\.\d+)?', typed.group())
        return float(number.group()) if number else None

    @classmethod
    def _sequence_from_query(cls, query):
        """Return the first token of ``query`` that looks like a protein sequence."""
        for word in query.split():
            if (len(word) >= cls._MIN_SEQUENCE_LENGTH
                    and all(c in cls._AMINO_ACIDS for c in word.upper())
                    and sum(c.isupper() for c in word) / len(word) > 0.8):
                return word
        return None

    @classmethod
    def _sequence_from_model(cls, value):
        """Validate a sequence suggested by the model; return it or None."""
        if not value:
            return None
        candidate = ''.join(value.split())
        if not all(c in cls._AMINO_ACIDS for c in candidate.upper()):
            print(f"Warning: ignoring model-suggested sequence that is not a protein sequence: {value!r}")
            return None
        if len(candidate) < cls._MIN_SEQUENCE_LENGTH:
            print("Warning: Sequence must be at least 25 residues long. Skipping sequence search.")
            return None
        return candidate

    @classmethod
    def _pdb_id_from_model(cls, value, query):
        """Return the model's PDB ID (upper-cased) if it is well-formed and
        actually appears in ``query``; otherwise None.
        """
        if not value:
            return None
        candidate = value.strip().strip('"\'.,[]').upper()
        if not cls._PDB_ID_RE.fullmatch(candidate):
            print(f"Warning: ignoring model-suggested PDB ID {value!r} (not a PDB ID).")
            return None
        # Guard against IDs the model invents: an explicit ID discards every
        # other criterion, so it must come from the user's own text.
        if not re.search(r'\b' + re.escape(candidate) + r'\b', query, re.IGNORECASE):
            print(f"Warning: ignoring PDB ID {candidate} that does not appear in the query.")
            return None
        return candidate

    @classmethod
    def _method_query(cls, value):
        """Build an ``exptl.method`` filter from the model's method, or None."""
        if not value:
            return None
        method = value.strip().rstrip('.').upper()
        exptl_values = cls._METHOD_VALUES.get(method)
        if exptl_values is None:
            print(f"Warning: unrecognised experimental method {value!r}; no method filter applied.")
            return None
        print(f"Adding experimental method filter: {method}")
        method_query = None
        for exptl_value in exptl_values:
            term = AttributeQuery(
                attribute="exptl.method",
                operator="exact_match",
                value=exptl_value,
            )
            # Several stored values (e.g. solution / solid-state NMR) are OR-ed.
            method_query = term if method_query is None else method_query | term
        return method_query
def pdbsummary(name, max_resolution=5.0):
    """Return a plain-text listing of structures found for ``name``.

    Args:
        name: Search term handed to ``ProteinQuery``.
        max_resolution: Resolution cutoff passed to ``ProteinQuery``
            (default 5.0, the value previously hard-coded).

    Returns:
        One numbered entry per structure returned by
        ``ProteinSearchEngine.search``, listing PDB ID, resolution, method,
        title, release date, sequence length and sequence; "" if none.
    """
    search_engine = ProteinSearchEngine()
    query = ProteinQuery(
        name,
        max_resolution=max_resolution,
    )
    results = search_engine.search(query)
    parts = []
    for i, structure in enumerate(results, 1):
        # NOTE(review): guards against a missing resolution (e.g. NMR entries),
        # which would make the ":.2f" format raise — confirm whether
        # ProteinSearchEngine can return such structures.
        resolution = (
            f"{structure.resolution:.2f}" if structure.resolution is not None else "N/A"
        )
        parts.append(f"\n{i}. PDB ID : {structure.pdb_id}\n")
        parts.append(f"\nResolution : {resolution} A \n")
        parts.append(f"Method : {structure.method}\n Title : {structure.title}\n")
        parts.append(f"Release Date : {structure.release_date}\n Sequence length: {len(structure.sequence)} aa\n")
        parts.append(f" Sequence:\n {structure.sequence}\n")
    # Build the text once instead of repeated `+=` concatenation.
    return "".join(parts)
def create_interactive_table(df):
    """Render ``df`` as a Plotly table figure.

    Args:
        df: DataFrame to display; every column becomes a table column and
            cell values are passed through unchanged.

    Returns:
        A ``go.Figure`` holding one ``go.Table`` trace, or an empty
        ``go.Figure()`` when ``df`` is empty.
    """
    if df.empty:
        return go.Figure()
    # Relative column widths: the longest rendered cell or header per column.
    # Values are stringified first so numeric/None columns do not raise.
    column_widths = [
        max(len(str(col)), int(df[col].astype(str).str.len().max()))
        for col in df.columns
    ]
    table = go.Figure(data=[go.Table(
        header=dict(
            values=list(df.columns),
            fill_color='paleturquoise',
            align='left',
            font=dict(size=14),
        ),
        cells=dict(
            values=[df[col] for col in df.columns],
            align='left',
            font=dict(size=13),
            height=30,
        ),
        columnwidth=column_widths,
    )])
    table.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        height=400,
        autosize=True,
    )
    return table
# Page layout: query box, example query, search button, search summary,
# results widget (output "results_table") and a CSV download button.
# NOTE(review): the `.table a` CSS targets links inside HTML tables; the
# results are drawn by a Plotly figure, so it may have no visible effect.
app_ui = ui.page_fluid(
    ui.tags.head(
        ui.tags.style("""
.table a {
color: #0d6efd;
text-decoration: none;
}
.table a:hover {
color: #0a58ca;
text-decoration: underline;
}
""")
    ),
    ui.h2("Advanced PDB Structure Search Tool"),
    ui.row(
        ui.column(12,
            ui.input_text("query", "Search Query",
                          value="Human insulin"),
        )
    ),
    ui.row(
        ui.column(12,
            ui.p("Example queries:"),
            ui.tags.ul(
                # Fixed: the Ångström sign was mis-encoded in this example.
                ui.tags.li("Human hemoglobin C resolution better than 2.5Å"),
            ),
        )
    ),
    ui.row(
        ui.column(12,
            ui.input_action_button("search", "Search", class_="btn-primary"),
        )
    ),
    ui.row(
        ui.column(12,
            ui.h4("Search Parameters:"),
            ui.output_text("search_conditions"),
        )
    ),
    ui.row(
        ui.column(12,
            ui.h4("Top 10 Results:"),
            output_widget("results_table"),
            ui.download_button("download", "Download Results")
        )
    )
)
def server(input, output, session):
    """Connect the search UI to ``PDBSearchAssistant``.

    Clicking "Search" runs the query once and stores the hits. The results
    table (first 10 hits), the search summary and the CSV download (all hits)
    are rendered from that stored result set.
    """
    # NOTE(review): this loads the language model once per browser session;
    # consider sharing one instance across sessions if start-up cost matters.
    assistant = PDBSearchAssistant()
    results_store = reactive.Value([])
    # Query text of the last search actually run (not the live input box).
    applied_query = reactive.Value("")

    @reactive.Effect
    @reactive.event(input.search)
    def _():
        query_text = input.query()
        results_store.set(assistant.search_pdb(query=query_text))
        applied_query.set(query_text)

    # Outputs are registered once here and re-render when results_store
    # changes, instead of being re-defined inside the effect on every click.
    @output
    @render_widget
    def results_table():
        # The page heading promises the top 10 results, so show only those.
        df = pd.DataFrame(results_store.get()).head(10)
        if not df.empty:
            df['PDB ID'] = df['PDB ID'].apply(
                lambda x: f'<a href="https://www.rcsb.org/3d-view/{x}" target="_blank">{x}</a>'
            )
        return create_interactive_table(df)

    @output
    @render.text
    def search_conditions():
        results = results_store.get()
        return f"""
Applied Search Conditions:
- Query: {applied_query.get()}
- Total structures found: {len(results)}
"""

    @output
    @render.download(filename="pdb_search_results.csv")
    def download():
        # render.download treats a *returned* str as a file path, so the CSV
        # content has to be yielded as a chunk instead.
        yield pd.DataFrame(results_store.get()).to_csv(index=False)
# Shiny application object combining the page layout with the server logic.
app = App(app_ui, server)

if __name__ == "__main__":
    # nest_asyncio patches asyncio to allow nested event loops — presumably so
    # the app can also be launched where a loop is already running (notebooks).
    import nest_asyncio
    nest_asyncio.apply()
    app.run(port=7860, host="")