# File: query / app.py
# Web file-viewer header captured with this file (not part of the program):
#   uploader: lkjjj26 | commit e0471a2 "update app.py" | size 28.3 kB
from transformers import pipeline
from rcsbsearchapi import TextQuery, AttributeQuery, Query
from rcsbsearchapi.search import Sort, SequenceQuery
import os
from dotenv import load_dotenv
from shiny import App, render, ui, reactive
import pandas as pd
import warnings
import re
from UniprotKB_P_Sequence_RCSB_API_test import ProteinQuery, ProteinSearchEngine
import plotly.graph_objects as go
from shinywidgets import output_widget, render_widget
import requests
import asyncio
# Install a blanket 'ignore' filter so no warnings are printed by this process.
warnings.filterwarnings('ignore')
# Load environment variables from .env file
load_dotenv()
# Disabled: point the Transformers cache at a local directory and create it.
# os.environ["TRANSFORMERS_CACHE"] = "./transformers_cache"
# os.makedirs("./transformers_cache", exist_ok=True)
class PDBSearchAssistant:
    """Turn free-text protein questions into RCSB PDB lookups.

    A text2text model extracts search parameters from the user's query; those
    parameters are combined with simple keyword heuristics to build an RCSB
    search (title text, sequence, resolution, method, PDB ID).  Queries that
    ask for the sequence of a specific PDB ID are answered through
    ``ProteinSearchEngine`` with a REST fallback.
    """

    REQUEST_TIMEOUT = 15        # seconds allowed for each RCSB REST request
    MAX_RESULTS = 10            # structure summaries returned per search
    MIN_SEQUENCE_LENGTH = 25    # shortest sequence accepted for a sequence search
    AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')
    RCSB_CORE_URL = "https://data.rcsb.org/rest/v1/core"

    # Words that signal a resolution constraint.  The value is the comparison
    # direction the word implies, or None when it only marks the topic.
    # Iteration order matters: the last matching directional term wins.
    RESOLUTION_TERMS = {
        'better': 'less',
        'best': 'less',
        'highest': 'less',
        'good': 'less',
        'fine': 'less',
        'worse': 'greater',
        'worst': 'greater',
        'lowest': 'greater',
        'poor': 'greater',
        'resolution': None,
        'å': None,
        'angstrom': None,
        'than': None,
        'under': 'less',
        'below': 'less',
        'above': 'greater',
        'over': 'greater'
    }

    # The prompt offers the shorthand X-RAY / EM / NMR, but the search is an
    # exact match on exptl.method, which is expected to hold full method names.
    # NOTE(review): NMR is mapped to solution NMR only — confirm this is wanted.
    METHOD_ALIASES = {
        'X-RAY': 'X-RAY DIFFRACTION',
        'XRAY': 'X-RAY DIFFRACTION',
        'X-RAY CRYSTALLOGRAPHY': 'X-RAY DIFFRACTION',
        'EM': 'ELECTRON MICROSCOPY',
        'CRYO-EM': 'ELECTRON MICROSCOPY',
        'CRYOEM': 'ELECTRON MICROSCOPY',
        'NMR': 'SOLUTION NMR',
    }

    _NULL_VALUES = frozenset({'none', 'n/a', ''})
    # "Label: value" pairs; a value ends at the next label, a newline, or the
    # end of the text, so one-line and multi-line model answers both parse.
    _FIELD_RE = re.compile(
        r'(Protein|Organism|Resolution|Sequence|PDB_ID|Method):[ \t]*'
        r'(.*?)[ \t]*'
        r'(?=(?:Protein|Organism|Resolution|Sequence|PDB_ID|Method):|\n|$)'
    )
    # Four characters, the first one a digit.
    _PDB_ID_RE = re.compile(r'\d[A-Za-z0-9]{3}')
    # A number that carries an explicit angstrom unit (input is lower-cased).
    _UNIT_NUMBER_RE = re.compile(r'\d+\.?\d*\s*(?:å|angstroms?)(?![a-z])')
    # A free-standing number with optional unit; digits embedded in names
    # such as "p53" or "cas9" are not matched.
    _STANDALONE_NUMBER_RE = re.compile(
        r'(?<![a-z0-9.])\d+\.?\d*\s*(?:å|angstroms?)?(?![a-z0-9])'
    )
    # Phrasings that request the sequence of one PDB entry (lower-cased input).
    _SEQUENCE_REQUEST_PATTERNS = (
        r"sequence\s+of\s+pdb\s+id\s+(\d[a-z0-9]{3})\b",
        r"sequence\s+for\s+pdb\s+id\s+(\d[a-z0-9]{3})\b",
        r"get\s+sequence\s+(\d[a-z0-9]{3})\b",
        r"\b(\d[a-z0-9]{3})\s+sequence",
    )

    PROMPT_TEMPLATE = """
Extract specific search parameters from the protein-related query:
1. Protein name or type
2. Resolution cutoff (in Å)
3. Sequence information
4. Specific PDB ID
5. Experimental method (X-RAY, EM, NMR)
6. Organism/Species information
Format:
Protein: [protein name or type]
Organism: [organism/species if mentioned]
Resolution: [maximum resolution in Å, if mentioned]
Sequence: [any sequence mentioned]
PDB_ID: [specific PDB ID if mentioned]
Method: [experimental method if mentioned]
Examples:
Query: "Find human insulin structures with X-ray better than 2.5Å resolution"
Protein: insulin
Organism: human
Resolution: 2.5
Sequence: none
PDB_ID: none
Method: X-RAY
Now analyze:
Query: {query}
"""

    def __init__(self, model_name="google/flan-t5-large"):
        """Create the text2text pipeline used to extract search parameters.

        Args:
            model_name: Hugging Face model identifier passed to ``pipeline``.
        """
        self.pipe = pipeline(
            "text2text-generation",
            model=model_name,
            max_new_tokens=512,
            temperature=0.3,
            torch_dtype="auto",
            device="cpu"
        )
        self.prompt_template = self.PROMPT_TEMPLATE

    # ------------------------------------------------------------------
    # Pure parsing helpers
    # ------------------------------------------------------------------
    @classmethod
    def _detect_resolution_intent(cls, query_lower):
        """Return ``(has_resolution_query, direction)`` for a lower-cased query.

        Terms are matched as whole words so that e.g. "ethanol" does not
        trigger on "than"; a bare digit no longer counts as a resolution
        request unless it carries an angstrom unit.
        """
        has_resolution = False
        direction = "less"
        for term, term_direction in cls.RESOLUTION_TERMS.items():
            if re.search(rf'(?<![a-z]){re.escape(term)}(?![a-z])', query_lower):
                has_resolution = True
                if term_direction:
                    direction = term_direction
        if cls._UNIT_NUMBER_RE.search(query_lower):
            has_resolution = True
        return has_resolution, direction

    @classmethod
    def _clean_text_query(cls, query_lower, has_resolution_query):
        """Strip resolution numbers/terms from the query used for title search."""
        cleaned = query_lower
        if has_resolution_query:
            cleaned = cls._STANDALONE_NUMBER_RE.sub('', cleaned)
            for term in cls.RESOLUTION_TERMS:
                cleaned = re.sub(
                    rf'(?<![a-z]){re.escape(term)}(?![a-z])', '', cleaned
                )
        # Collapse the gaps left behind by the removals.
        return ' '.join(cleaned.split())

    @classmethod
    def _parse_llm_response(cls, response):
        """Return ``{label: value}`` for every non-empty field in the model text."""
        fields = {}
        for match in cls._FIELD_RE.finditer(response):
            value = match.group(2).strip()
            if value.lower() not in cls._NULL_VALUES:
                fields[match.group(1)] = value
        return fields

    @classmethod
    def _normalize_method(cls, value):
        """Upper-case a method name and expand known shorthand; None if empty."""
        if not value:
            return None
        method = value.upper()
        return cls.METHOD_ALIASES.get(method, method)

    @classmethod
    def _valid_pdb_id(cls, value):
        """Return the value if it is shaped like a PDB ID, otherwise None."""
        if value and cls._PDB_ID_RE.fullmatch(value.strip()):
            return value.strip()
        return None

    @classmethod
    def _find_sequence(cls, query):
        """Return the first word of the query that looks like a protein sequence."""
        for word in query.split():
            if (len(word) >= cls.MIN_SEQUENCE_LENGTH
                    and all(c in cls.AMINO_ACIDS for c in word.upper())
                    and sum(c.isupper() for c in word) / len(word) > 0.8):
                return word
        return None

    @classmethod
    def _valid_sequence(cls, value):
        """Validate a model-supplied sequence; None when unusable."""
        if not value:
            return None
        candidate = value.strip().upper()
        if len(candidate) < cls.MIN_SEQUENCE_LENGTH:
            print("Warning: Sequence must be at least 25 residues long. Skipping sequence search.")
            return None
        if not all(c in cls.AMINO_ACIDS for c in candidate):
            return None
        return candidate

    @staticmethod
    def _entry_identifier(entry):
        """Extract an upper-cased identifier from a search result entry."""
        if isinstance(entry, dict):
            identifier = entry.get('identifier')
        elif hasattr(entry, 'identifier'):
            identifier = entry.identifier
        else:
            identifier = entry
        return str(identifier or '').upper()

    # ------------------------------------------------------------------
    # Structure search
    # ------------------------------------------------------------------
    def _fetch_structure_summary(self, pdb_id):
        """Fetch one entry from the RCSB REST API and return a display record.

        Returns None when the request fails or the entry is not found.
        """
        structure_url = f"{self.RCSB_CORE_URL}/entry/{pdb_id}"
        try:
            response = requests.get(structure_url, timeout=self.REQUEST_TIMEOUT)
            if response.status_code != 200:
                return None
            structure_data = response.json()
        except (requests.RequestException, ValueError) as e:
            print(f"Error processing entry: {str(e)}")
            return None

        # Entries may lack a resolution, so fall back to "N/A" instead of
        # failing on a missing/None value.
        resolutions = (structure_data.get('rcsb_entry_info') or {}).get('resolution_combined') or []
        if resolutions and resolutions[0] is not None:
            resolution = f"{resolutions[0]:.2f}Å"
        else:
            resolution = "N/A"
        experiments = structure_data.get('exptl') or [{}]
        return {
            'PDB ID': pdb_id,
            'Resolution': resolution,
            'Method': experiments[0].get('method', 'Unknown'),
            'Title': (structure_data.get('struct') or {}).get('title', 'N/A'),
            'Release Date': (structure_data.get('rcsb_accession_info') or {}).get('initial_release_date', 'N/A')
        }

    def _collect_results(self, session):
        """Turn search hits into at most ``MAX_RESULTS`` summary records."""
        results = []
        try:
            for entry in session:
                entry_id = self._entry_identifier(entry)
                if len(entry_id) != 4:  # only 4-character entry IDs are handled
                    continue
                summary = self._fetch_structure_summary(entry_id)
                if summary is None:
                    continue
                results.append(summary)
                if len(results) >= self.MAX_RESULTS:
                    break
        except Exception as e:
            # Best effort: keep whatever was collected before the failure.
            print(f"Error processing results: {str(e)}")
            print(f"Error type: {type(e)}")
        print(f"Found {len(results)} structures")
        return results

    def search_pdb(self, query):
        """Search RCSB PDB for structures matching a free-text query.

        Args:
            query: The user's query, in its original letter case (sequence
                detection relies on upper-case residues).

        Returns:
            A list of dicts with keys 'PDB ID', 'Resolution', 'Method',
            'Title' and 'Release Date'; empty on failure or no match.
        """
        try:
            # Ask the model for structured parameters.
            formatted_prompt = self.prompt_template.format(query=query)
            response = self.pipe(formatted_prompt)[0]['generated_text']
            print("Generated parameters:", response)

            query_lower = query.lower()
            has_resolution_query, resolution_direction = self._detect_resolution_intent(query_lower)
            fields = self._parse_llm_response(response)

            # The model's resolution is trusted only if the user asked for one.
            resolution_limit = None
            if has_resolution_query and 'Resolution' in fields:
                number = re.search(r'\d+(?:\.\d+)?', fields['Resolution'])
                if number:
                    resolution_limit = float(number.group())
            method = self._normalize_method(fields.get('Method'))
            pdb_id = self._valid_pdb_id(fields.get('PDB_ID'))
            # A sequence typed in the query takes precedence over the model's.
            sequence = self._find_sequence(query) or self._valid_sequence(fields.get('Sequence'))

            queries = []
            if sequence:
                print(f"Adding sequence search with identity 100% for sequence: {sequence}")
                queries.append(SequenceQuery(
                    sequence,
                    identity_cutoff=1.0,  # 100% identity
                    evalue_cutoff=1,
                    sequence_type="protein"
                ))
            else:
                clean_query = self._clean_text_query(query_lower, has_resolution_query)
                print("Cleaned query:", clean_query)
                if clean_query:
                    queries.append(AttributeQuery(
                        attribute="struct.title",
                        operator="contains_phrase",
                        value=clean_query
                    ))

            if resolution_limit:
                operator = "less_or_equal" if resolution_direction == "less" else "greater_or_equal"
                print(f"Adding resolution filter: {operator} {resolution_limit}Å")
                queries.append(AttributeQuery(
                    attribute="rcsb_entry_info.resolution_combined",
                    operator=operator,
                    value=resolution_limit
                ))

            if pdb_id:
                print(f"Searching for specific PDB ID: {pdb_id}")
                # A direct ID lookup replaces every filter collected so far.
                queries = [AttributeQuery(
                    attribute="rcsb_id",
                    operator="exact_match",
                    value=pdb_id.upper()
                )]

            if method:
                print(f"Adding experimental method filter: {method}")
                queries.append(AttributeQuery(
                    attribute="exptl.method",
                    operator="exact_match",
                    value=method
                ))

            if not queries:
                return []

            # Combine all filters with AND.
            final_query = queries[0]
            for q in queries[1:]:
                final_query = final_query & q
            print("Final query:", final_query)

            return self._collect_results(final_query.exec())
        except Exception as e:
            # Top-level boundary: the UI expects a list, never an exception.
            print(f"Error during search: {str(e)}")
            print(f"Error type: {type(e)}")
            return []

    # ------------------------------------------------------------------
    # Sequence lookup
    # ------------------------------------------------------------------
    def get_sequences_by_pdb_id(self, pdb_id):
        """Get sequence records for a PDB structure.

        Tries ``ProteinSearchEngine`` first and falls back to direct REST
        calls when it fails or has no exact match for ``pdb_id``.

        Returns:
            A list of dicts with at least 'chain_id', 'entity_id',
            'description', 'sequence' and 'length'.
        """
        try:
            search_engine = ProteinSearchEngine()
            # Very high resolution limit so that no structure is filtered out.
            query = ProteinQuery(
                name=pdb_id,
                max_resolution=100.0
            )
            results = search_engine.search(query)
            wanted = pdb_id.upper()
            for structure in results or []:
                if structure.pdb_id.upper() == wanted:
                    return [{
                        'chain_id': 'ALL',  # chains are not separated by this path
                        'entity_id': '1',
                        'description': structure.title,
                        'sequence': structure.sequence,
                        'length': len(structure.sequence),
                        'resolution': structure.resolution,
                        'method': structure.method,
                        'release_date': structure.release_date
                    }]
        except Exception as e:
            # ProteinSearchEngine is opaque here; any failure means "use the fallback".
            print(f"Error in ProteinSearchEngine search for PDB ID {pdb_id}: {str(e)}")
            return self._get_sequences_by_direct_api(pdb_id)

        print(f"No results found using ProteinSearchEngine, trying direct API call...")
        return self._get_sequences_by_direct_api(pdb_id)

    def _get_sequences_by_direct_api(self, pdb_id):
        """Fallback: read polymer entity sequences straight from the RCSB REST API.

        The entry record lists its polymer entity IDs; each entity record
        supplies the sequence, description and the chains that carry it.
        One record is returned per entity.
        """
        try:
            entry_response = requests.get(
                f"{self.RCSB_CORE_URL}/entry/{pdb_id}", timeout=self.REQUEST_TIMEOUT
            )
            if entry_response.status_code != 200:
                return []
            identifiers = entry_response.json().get('rcsb_entry_container_identifiers') or {}
            entity_ids = identifiers.get('polymer_entity_ids') or []

            sequences = []
            for entity_id in entity_ids:
                entity_response = requests.get(
                    f"{self.RCSB_CORE_URL}/polymer_entity/{pdb_id}/{entity_id}",
                    timeout=self.REQUEST_TIMEOUT
                )
                if entity_response.status_code != 200:
                    continue
                entity_data = entity_response.json()
                sequence = (entity_data.get('entity_poly') or {}).get('pdbx_seq_one_letter_code') or ''
                description = (entity_data.get('rcsb_polymer_entity') or {}).get('pdbx_description', 'N/A')
                chain_ids = (entity_data.get('rcsb_polymer_entity_container_identifiers') or {}).get('auth_asym_ids') or []
                sequences.append({
                    'chain_id': ','.join(chain_ids) or 'N/A',
                    'entity_id': entity_id,
                    'description': description,
                    'sequence': sequence,
                    'length': len(sequence)
                })
            return sequences
        except (requests.RequestException, ValueError) as e:
            print(f"Error in direct API call for PDB ID {pdb_id}: {str(e)}")
            return []

    # ------------------------------------------------------------------
    # Query routing
    # ------------------------------------------------------------------
    def analyze_query_type(self, query):
        """Classify a query as a sequence request or a structure search.

        Returns:
            ``{"type": "sequence", "pdb_id": <UPPER-CASE ID>}`` when the query
            asks for the sequence of a PDB ID, otherwise
            ``{"type": "structure", "query": <query, stripped>}``.  The
            structure query keeps its original letter case because
            ``search_pdb`` detects sequences by their upper-case residues.
        """
        print(f"\nAnalyzing query: '{query}'")
        stripped = query.strip()
        lowered = stripped.lower()

        for pattern in self._SEQUENCE_REQUEST_PATTERNS:
            match = re.search(pattern, lowered)
            if match:
                pdb_id = match.group(1).upper()
                print(f"Match found! PDB ID: {pdb_id}")
                return {
                    "type": "sequence",
                    "pdb_id": pdb_id
                }

        print("No sequence pattern matched, treating as structure search")
        return {
            "type": "structure",
            "query": stripped
        }

    def process_query(self, query):
        """Route a query to sequence lookup or structure search.

        Returns:
            ``{"type": "sequence" | "structure", "results": [...]}``.
        """
        query_info = self.analyze_query_type(query)
        if query_info["type"] == "sequence":
            return {
                "type": "sequence",
                "results": self.get_sequences_by_pdb_id(query_info["pdb_id"])
            }
        return {
            "type": "structure",
            "results": self.search_pdb(query_info["query"])
        }
def create_interactive_table(df):
    """Render search results as a Plotly table with linked PDB IDs.

    Args:
        df: DataFrame of structure records.  The columns 'PDB ID',
            'Resolution', 'Method', 'Title' and 'Release Date' are displayed;
            any that are missing are shown as 'N/A'.

    Returns:
        A ``plotly.graph_objects.Figure``; an empty figure when ``df`` is empty.
    """
    if df.empty:
        return go.Figure()

    column_order = ['PDB ID', 'Resolution', 'Method', 'Title', 'Release Date']
    # reindex() returns a new frame (so the caller's data is never mutated)
    # and tolerates missing columns instead of raising KeyError.
    table_df = df.reindex(columns=column_order)

    # Show release dates as YYYY-MM-DD; values that cannot be parsed
    # (e.g. the 'N/A' placeholder) become 'N/A' rather than raising.
    release_dates = pd.to_datetime(table_df['Release Date'], errors='coerce')
    table_df['Release Date'] = release_dates.dt.strftime('%Y-%m-%d')
    table_df = table_df.fillna('N/A')

    # One list per column; the first column (PDB ID) links to its RCSB page.
    cell_values = [table_df[col].tolist() for col in table_df.columns]
    cell_values[0] = [
        f'<a href="https://www.rcsb.org/structure/{cell}">{cell}</a>'
        for cell in cell_values[0]
    ]

    row_count = len(table_df)
    table = go.Figure(data=[go.Table(
        header=dict(
            values=list(table_df.columns),
            fill_color='paleturquoise',
            align='center',
            font=dict(size=16),
        ),
        cells=dict(
            values=cell_values,
            align='center',
            font=dict(size=15),
            height=35
        ),
        columnwidth=[80, 80, 100, 400, 100],
        customdata=[['html'] * row_count if i == 0 else [''] * row_count
                    for i in range(len(table_df.columns))],
        hoverlabel=dict(bgcolor='white')
    )])

    table.update_layout(
        margin=dict(l=20, r=20, t=20, b=20),
        height=450,
        autosize=True
    )
    return table
# Stylesheet injected into the page <head>.
_APP_CSS = """
.container-fluid {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
.table a {
color: #0d6efd;
text-decoration: none;
}
.table a:hover {
color: #0a58ca;
text-decoration: underline;
}
.shiny-input-container {
max-width: 100%;
margin: 0 auto;
}
#query {
height: 100px;
font-size: 16px;
padding: 15px;
width: 80%;
margin: 0 auto;
display: block;
}
.content-wrapper {
text-align: center;
max-width: 1000px;
margin: 0 auto;
}
.search-button {
margin: 20px 0;
}
h2, h4 {
text-align: center;
margin: 20px 0;
}
.example-box {
background-color: #f8f9fa;
border-radius: 8px;
padding: 20px;
margin: 20px auto;
width: 80%;
text-align: left;
}
.example-box p {
font-weight: bold;
margin-bottom: 10px;
padding-left: 20px;
}
.example-box ul {
margin: 0;
padding-left: 40px;
}
.example-box li {
word-wrap: break-word;
margin: 10px 0;
line-height: 1.5;
}
.query-label {
display: block;
text-align: left;
margin-bottom: 10px;
margin-left: 10%;
font-weight: bold;
}
.status-box {
background-color: #f8f9fa;
border-radius: 8px;
padding: 15px;
margin: 20px auto;
width: 80%;
text-align: left;
}
.status-label {
font-weight: bold;
margin-right: 10px;
}
.status-ready {
color: #198754; /* Bootstrap success color */
font-weight: bold;
}
.sequence-results {
width: 80%;
margin: 20px auto;
text-align: left;
font-family: monospace;
white-space: pre-wrap;
word-wrap: break-word;
background-color: #f8f9fa;
border-radius: 8px;
padding: 20px;
overflow-x: hidden;
}
.sequence-text {
word-break: break-all;
margin: 10px 0;
line-height: 1.5;
}
.status-spinner {
display: none;
margin-left: 10px;
vertical-align: middle;
}
.status-spinner.active {
display: inline-block;
}
"""

# Page layout: query box, example queries, search button, status line,
# results table with CSV download, and the sequence output panel.
app_ui = ui.page_fluid(
    ui.tags.head(
        ui.tags.style(_APP_CSS)
    ),
    ui.div(
        {"class": "content-wrapper"},
        ui.h2("Advanced PDB Structure Search Tool"),
        # Query input; the label is a separate tag so it can be styled.
        ui.row(
            ui.column(12,
                ui.tags.label(
                    "Search Query",
                    {"class": "query-label", "for": "query"}
                ),
                ui.input_text(
                    "query",
                    "",
                    value="Human insulin",
                    width="100%"
                ),
            )
        ),
        # Example queries shown to the user.
        ui.row(
            ui.column(12,
                ui.div(
                    {"class": "example-box"},
                    ui.p("Example queries:"),
                    ui.tags.ul(
                        ui.tags.li("Human hemoglobin C resolution better than 2.5Å"),
                        ui.tags.li("Find structures containing sequence MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNTNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRAALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSRWYNQTPNRAKRVITTFRTGTWDAYKNL"),
                        ui.tags.li("Sequence of PDB ID 8ET6"),
                        ui.tags.li("Get sequence 7BZ5")
                    )
                )
            )
        ),
        # Search button (large primary style).
        ui.row(
            ui.column(12,
                ui.div(
                    {"class": "search-button"},
                    ui.input_action_button("search", "Search",
                                           class_="btn-primary btn-lg")
                )
            )
        ),
        # Status line fed by the "search_status" output.
        ui.row(
            ui.column(12,
                ui.h4("Search Parameters:"),
                ui.div(
                    {"class": "status-box"},
                    ui.tags.span("Status: ", class_="status-label"),
                    ui.output_text("search_status", inline=True),
                    ui.tags.div(
                        {"class": "status-spinner"},
                        ui.tags.i({"class": "fas fa-spinner fa-spin"})
                    )
                )
            )
        ),
        # Results table widget and CSV download button.
        ui.row(
            ui.column(12,
                ui.h4("Top 10 Results:"),
                output_widget("results_table"),
                ui.download_button("download", "Download Results",
                                   class_="btn btn-info btn-lg")
            )
        ),
        # Sequence text output.
        ui.row(
            ui.column(12,
                ui.div(
                    {"class": "sequence-results", "id": "sequence-results"},
                    ui.h4("Sequences:"),
                    ui.output_text("sequence_output")
                )
            )
        )
    )
)
def server(input, output, session):
    """Shiny server function: run searches and feed the output widgets.

    State lives in two reactive values: ``results_store`` holds the last
    ``process_query`` result (``{"type", "results"}``) and ``status_store``
    holds the status line text.  Every output reads from these stores.
    """
    assistant = PDBSearchAssistant()
    results_store = reactive.Value({"type": None, "results": []})
    status_store = reactive.Value("Ready")

    @reactive.Effect
    @reactive.event(input.search)
    def _():
        # Runs once per click on the Search button.
        status_store.set("Searching...")
        try:
            query_results = assistant.process_query(input.query())
        except Exception as e:
            # Never leave the status stuck on "Searching..." after a failure.
            print(f"Error during search: {str(e)}")
            status_store.set("Search failed")
            return
        results_store.set(query_results)
        if query_results["type"] == "sequence" and not query_results["results"]:
            status_store.set("No sequences found")
        else:
            status_store.set("Ready")

    @output
    @render_widget
    def results_table():
        # Build the frame from the stored results; only structure searches
        # populate the table, anything else renders an empty figure.
        current_results = results_store.get()
        if current_results["type"] != "structure":
            return create_interactive_table(pd.DataFrame())
        return create_interactive_table(pd.DataFrame(current_results["results"]))

    @output
    @render.text
    def search_status():
        return status_store.get()

    @output
    @render.download(filename="pdb_search_results.csv")
    def download():
        # Yield the CSV text: a string *returned* from a download handler
        # is interpreted as a path to a file on disk.
        current_results = results_store.get()
        yield pd.DataFrame(current_results["results"]).to_csv(index=False)

    @output
    @render.text
    def sequence_output():
        current_results = results_store.get()
        if current_results["type"] != "sequence":
            return ""
        sequences = current_results["results"]
        if not sequences:
            return "No sequences found"

        output_text = []
        for seq in sequences:
            output_text.append(f"\nChain {seq['chain_id']} (Entity {seq['entity_id']}):")
            output_text.append(f"Description: {seq['description']}")
            output_text.append(f"Length: {seq['length']} residues")
            output_text.append("Sequence:")
            # Wrap the sequence at 60 characters per line.
            sequence = seq['sequence']
            formatted_sequence = '\n'.join([sequence[i:i+60] for i in range(0, len(sequence), 60)])
            output_text.append(formatted_sequence)
            output_text.append("-" * 60)
        return "\n".join(output_text)
# Application object combining the page layout with the server function.
app = App(ui=app_ui, server=server)

if __name__ == "__main__":
    # Imported only for direct execution; apply() is called before the app
    # starts (presumably so app.run() works when an event loop is already
    # active — TODO confirm).
    import nest_asyncio

    nest_asyncio.apply()
    app.run(port=7862, host="0.0.0.0")