import React, { RefObject } from 'react';
import mermaid from 'mermaid';
import Button from '@mui/material/Button';

/** Props accepted by the `Mermaid` component. */
type MermaidProps = {
  /** Mermaid diagram source text, rendered inside the component's container. */
  chart: string;
};

/** Internal state of the `Mermaid` component. */
type MermaidState = {
  /** DOM id of the diagram container; also used as the base name of downloaded files. */
  diagramId: string;
};

// One-time, module-level mermaid configuration shared by every Mermaid instance:
// the 'neutral' theme, with startOnLoad enabled.
mermaid.initialize({ theme: 'neutral', startOnLoad: true });

/**
 * Renders a mermaid diagram from `props.chart`. Offers buttons to download the
 * rendered diagram as SVG, PNG or JPEG.
 *
 * The chart text is placed inside a `.mermaid` container, and mermaid is asked to
 * process it after mount and whenever `chart` changes.
 */
class Mermaid extends React.Component<MermaidProps, MermaidState> {
  /**
   * Delay before revoking object URLs handed to a download link. Revoking
   * synchronously after `click()` can abort the download in some browsers.
   */
  private static readonly objectUrlRevokeDelayMs = 10000;

  /** Container that holds the chart text and, once processed, the generated `<svg>`. */
  readonly mermaidRef: RefObject<HTMLDivElement>;

  constructor(props: MermaidProps) {
    super(props);
    this.state = {
      // Random 9-character base-36 suffix, so multiple diagrams on a page get distinct ids.
      diagramId: `mermaid-${Math.random().toString(36).slice(2, 11)}`,
    };
    this.mermaidRef = React.createRef<HTMLDivElement>();
  }

  componentDidMount() {
    this.renderMermaid();
  }

  componentDidUpdate(prevProps: MermaidProps) {
    if (prevProps.chart === this.props.chart) {
      return;
    }
    // Clear mermaid's "already processed" marker so the new chart text is rendered again.
    this.mermaidRef.current?.removeAttribute('data-processed');
    this.renderMermaid();
  }

  /**
   * Asks mermaid to process the page's `.mermaid` containers.
   * NOTE(review): `contentLoaded` is mermaid's page-level entry point and is not scoped
   * to this component's node. Confirm this is intended for pages with many diagrams.
   */
  renderMermaid() {
    mermaid.contentLoaded();
  }

  /**
   * Serializes the rendered diagram and downloads it as a standalone `.svg` file.
   * Logs an error and does nothing if no `<svg>` has been generated yet.
   */
  downloadSVG = () => {
    const svg = this.mermaidRef.current?.querySelector('svg');
    if (!svg) {
      console.error('No SVG found to download');
      return;
    }

    let source = new XMLSerializer().serializeToString(svg);
    // A standalone .svg file needs namespace declarations that inline SVG in HTML can omit.
    if (!/^<svg[^>]+xmlns="http:\/\/www\.w3\.org\/2000\/svg"/.test(source)) {
      source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
    }
    if (!/^<svg[^>]+"http:\/\/www\.w3\.org\/1999\/xlink"/.test(source)) {
      source = source.replace(/^<svg/, '<svg xmlns:xlink="http://www.w3.org/1999/xlink"');
    }
    source = '<?xml version="1.0" standalone="no"?>\r\n' + source;
    const url = 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(source);
    this.triggerDownload(url, 'svg');
  };

  /**
   * Rasterizes the rendered diagram through a canvas and downloads it as PNG or JPEG.
   *
   * Every object URL created here is revoked, on both the success and the failure
   * paths. JPEG output is painted onto a white background, because JPEG cannot store
   * transparency and transparent pixels would otherwise come out black.
   *
   * @param format - Raster format to produce.
   */
  downloadPNGorJPG = (format: 'png' | 'jpeg') => {
    const svg = this.mermaidRef.current?.querySelector('svg');
    if (!svg) {
      console.error('No SVG found to convert');
      return;
    }

    // On-screen size, used only if the serialized SVG has no intrinsic size
    // (e.g. width="100%"), which would otherwise yield an empty 0x0 canvas.
    const rendered = svg.getBoundingClientRect();
    const source = new XMLSerializer().serializeToString(svg);
    const svgUrl = URL.createObjectURL(new Blob([source], { type: 'image/svg+xml;charset=utf-8' }));
    const img = new Image();
    img.crossOrigin = 'anonymous'; // Must be set before `src`.

    img.onload = () => {
      URL.revokeObjectURL(svgUrl);
      const width = img.naturalWidth || Math.ceil(rendered.width);
      const height = img.naturalHeight || Math.ceil(rendered.height);
      if (width === 0 || height === 0) {
        console.error('SVG has no measurable size to convert');
        return;
      }

      const canvas = document.createElement('canvas');
      canvas.width = width;
      canvas.height = height;
      const ctx = canvas.getContext('2d');
      if (!ctx) {
        console.error('2D canvas context is unavailable');
        return;
      }
      if (format === 'jpeg') {
        ctx.fillStyle = '#ffffff';
        ctx.fillRect(0, 0, width, height);
      }
      ctx.drawImage(img, 0, 0, width, height);

      canvas.toBlob((blob) => {
        if (!blob) {
          console.error(`Failed to encode diagram as ${format}`);
          return;
        }
        const imgUrl = URL.createObjectURL(blob);
        this.triggerDownload(imgUrl, format);
        window.setTimeout(() => URL.revokeObjectURL(imgUrl), Mermaid.objectUrlRevokeDelayMs);
      }, `image/${format}`);
    };
    img.onerror = () => {
      URL.revokeObjectURL(svgUrl);
      console.error('Failed to load SVG for conversion');
    };
    img.src = svgUrl;
  };

  /**
   * Starts a browser download of `url` by clicking a temporary `<a download>` element.
   * The file name is `<diagramId>.<format>`.
   */
  triggerDownload = (url: string, format: 'svg' | 'png' | 'jpeg') => {
    const link = document.createElement('a');
    link.href = url;
    link.download = `${this.state.diagramId}.${format}`;
    // Append the link to the document before clicking it, then remove it.
    document.body.appendChild(link);
    link.click();
    document.body.removeChild(link);
  };

  render() {
    return (
      <>
        <div
          ref={this.mermaidRef}
          id={this.state.diagramId}
          className="mermaid"
          style={{
            maxHeight: '30%',
            overflow: 'auto',
            margin: '0 auto',
            display: 'block',
            width: '80%',
          }}
        >
          {/* Diagram source text; mermaid processes this container and the downloads read its <svg>. */}
          {this.props.chart}
        </div>
        <div>
          <Button
            variant="outlined"
            color="secondary"
            onClick={this.downloadSVG}
            style={{ margin: '10px' }}
          >
            Download SVG
          </Button>
          <Button
            variant="outlined"
            color="secondary"
            onClick={() => this.downloadPNGorJPG('png')}
            style={{ margin: '10px' }}
          >
            Download PNG
          </Button>
          <Button
            variant="outlined"
            color="secondary"
            onClick={() => this.downloadPNGorJPG('jpeg')}
            style={{ margin: '10px' }}
          >
            Download JPG
          </Button>
        </div>
      </>
    );
  }
}

// Memoized so the diagram is not re-rendered when the parent re-renders with the same props.
const MemoizedMermaid = React.memo(Mermaid);
export default MemoizedMermaid;
